diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 7ac7b0a00..8d4fbc46c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1023,17 +1023,14 @@ class BalancerFunction(torch.autograd.Function):
 
             # part of loss that relates to mean / stddev
             m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
-            if min_rms <= 0:
-                r_loss_min = 0.0
-            else:
-                r_loss_min = (min_rms / rms).log().relu()
-            r_loss_max = (rms / max_rms).log().relu()
-
             # put a much larger scale on the RMS-max-limit loss, so that if both it and the
             # m_loss are violated we fix the RMS loss first.
-            loss = (m_loss + r_loss_min + 100.0 * r_loss_max).sum()
+            rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+            r_loss = (rms_clamped / rms).log().abs()
 
-            loss.backward()
+            loss = (m_loss + r_loss)
+
+            loss.backward(gradient=torch.ones_like(loss))
 
             loss_grad = x.grad
             loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)