mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix backprop bug
This commit is contained in:
parent
cd216f50b6
commit
3d9ddc2016
@ -220,7 +220,7 @@ def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||
class ExpScaleSwishFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||
ctx.save_for_backward(x, scale)
|
||||
ctx.save_for_backward(x.detach(), scale.detach())
|
||||
ctx.speed = speed
|
||||
return _exp_scale_swish(x, scale, speed)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user