diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index cd7faba8a..8007d7f32 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -812,7 +812,7 @@ class DecorrelateFunction(torch.autograd.Function):
         # the loss starts getting quite small (less than 1), we start using
         # smaller derivatives.
         decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-        scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-10)) ** 0.5
+        scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-20)) ** 0.5
         decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1)
         x_grad = x_grad + decorr_x_grad
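
For context, the touched line rescales the decorrelation gradient so its norm is tied to the norm of the incoming gradient, with the epsilon only guarding against division by zero; shrinking it from 1.0e-10 to 1.0e-20 keeps that ratio meaningful even when the decorrelation gradient is extremely small. Below is a minimal standalone sketch of this rescaling step, not the actual `DecorrelateFunction.backward()` implementation: the helper name `rescale_decorr_grad` is hypothetical, and the assumption that the squared norms are taken over the last dimension is inferred from the `scale.unsqueeze(-1)` in the patch.

```python
import torch


def rescale_decorr_grad(x_grad, decorr_x_grad, decorr_loss_scale, eps=1.0e-20):
    # Hypothetical sketch of the rescaling in the patched line: the
    # decorrelation gradient is scaled so that its per-frame norm matches
    # decorr_loss_scale times the norm of the original gradient, then added
    # back onto the original gradient.
    x_grad_old_sqnorm = (x_grad ** 2).sum(dim=-1)        # assumed: per-frame squared norm
    decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=-1)
    scale = decorr_loss_scale * (
        x_grad_old_sqnorm / (decorr_x_grad_sqnorm + eps)
    ) ** 0.5
    return x_grad + decorr_x_grad * scale.unsqueeze(-1)


# With eps = 1.0e-10, a decorrelation gradient whose squared norm is far below
# 1e-10 would effectively be divided by the epsilon rather than by its own
# norm, collapsing the intended rescaling; 1.0e-20 pushes that cutoff much
# lower while still preventing division by zero.
x_grad = torch.randn(4, 8)
decorr_x_grad = torch.randn(4, 8) * 1.0e-8   # very small decorrelation gradient
out = rescale_decorr_grad(x_grad, decorr_x_grad, decorr_loss_scale=0.1)
print(out.shape)
```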