diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 24477f91e..dcf814129 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1030,12 +1030,11 @@ class BalancerFunction(torch.autograd.Function): # part of loss that relates to mean / stddev m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - rms_clamped = rms.clamp(min=min_rms, max=max_rms) if min_rms <= 0: r_loss_min = 0.0 else: - r_loss_min = (min_rms / rms).log().relu().sum() - r_loss_max = (rms / max_rms).log().relu().sum() + r_loss_min = (min_rms / rms).log().relu() + r_loss_max = (rms / max_rms).log().relu() # put a much larger scale on the RMS-max-limit loss, so that if both it and the # m_loss are violated we fix the RMS loss first.