diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index cf2f7f3aa..174ccf39e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -728,7 +728,7 @@ class LimitParamValue(torch.autograd.Function): x, = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)