diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index f370f3386..706c3c1c7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -449,7 +449,7 @@ class BasicNormFunction(torch.autograd.Function): for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5 - ans = x * scales + ans = x * scales - bias ctx.save_for_backward(ans if store_output_for_backprop else x, scales, bias, eps) return ans @@ -471,7 +471,7 @@ class BasicNormFunction(torch.autograd.Function): with torch.enable_grad(): # recompute scales from x, bias and eps. scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5 - ans = x * scales + ans = x * scales - bias ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), eps.grad, None, None