mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix backprop bug
This commit is contained in:
parent
cd216f50b6
commit
3d9ddc2016
@ -220,7 +220,7 @@ def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
|||||||
class ExpScaleSwishFunction(torch.autograd.Function):
|
class ExpScaleSwishFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||||
ctx.save_for_backward(x, scale)
|
ctx.save_for_backward(x.detach(), scale.detach())
|
||||||
ctx.speed = speed
|
ctx.speed = speed
|
||||||
return _exp_scale_swish(x, scale, speed)
|
return _exp_scale_swish(x, scale, speed)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user