Fix backprop bug

This commit is contained in:
Daniel Povey 2022-03-04 12:29:44 +08:00
parent cd216f50b6
commit 3d9ddc2016

View File

@ -220,7 +220,7 @@ def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor:
class ExpScaleSwishFunction(torch.autograd.Function): class ExpScaleSwishFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
ctx.save_for_backward(x, scale) ctx.save_for_backward(x.detach(), scale.detach())
ctx.speed = speed ctx.speed = speed
return _exp_scale_swish(x, scale, speed) return _exp_scale_swish(x, scale, speed)