diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 9373e6c44..9bc62ba9a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1033,15 +1033,14 @@ class BalancerFunction(torch.autograd.Function): rms_clamped = rms.clamp(min=min_rms, max=max_rms) r_loss = (rms_clamped / rms).log().abs() - loss = (m_loss + r_loss).sum() + # put a much larger scale on the RMS loss, so that if both are violated we fix + # the RMS loss first. + loss = (m_loss + 100.0 * r_loss).sum() loss.backward() loss_grad = x.grad loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) - # use 4 times the normal grad_scale in dimensions where the max_rms constraint is violated; - # we sometimes have trouble enforcing this one. - grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms)) loss_grad = loss_grad * (grad_scale / loss_grad_rms) x_grad_float = x_grad.to(torch.float32)