diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 9bc62ba9a..24477f91e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1031,11 +1031,15 @@ class BalancerFunction(torch.autograd.Function): m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() + if min_rms <= 0: + r_loss_min = 0.0 + else: + r_loss_min = (min_rms / rms).log().relu().sum() + r_loss_max = (rms / max_rms).log().relu().sum() - # put a much larger scale on the RMS loss, so that if both are violated we fix - # the RMS loss first. - loss = (m_loss + 100.0 * r_loss).sum() + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + loss = (m_loss + r_loss_min + 100.0 * r_loss_max).sum() loss.backward() loss_grad = x.grad