From 8db0636f1d95f683bb40509ae96b82987ca879cb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 1 Jan 2023 00:38:07 +0800
Subject: [PATCH] Fix to Balancer to treat max-rms and min-rms losses
 separately, only max-rms loss scaled up

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 9bc62ba9a..24477f91e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1031,11 +1031,15 @@ class BalancerFunction(torch.autograd.Function):
                     m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
 
                     rms_clamped = rms.clamp(min=min_rms, max=max_rms)
-                    r_loss = (rms_clamped / rms).log().abs()
+                    if min_rms <= 0:
+                        r_loss_min = 0.0
+                    else:
+                        r_loss_min = (min_rms / rms).log().relu().sum()
+                    r_loss_max = (rms / max_rms).log().relu().sum()
 
-                    # put a much larger scale on the RMS loss, so that if both are violated we fix
-                    # the RMS loss first.
-                    loss = (m_loss + 100.0 * r_loss).sum()
+                    # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+                    # m_loss are violated we fix the RMS loss first.
+                    loss = (m_loss + r_loss_min + 100.0 * r_loss_max).sum()
 
                     loss.backward()
                     loss_grad = x.grad
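A minimal standalone sketch of the idea behind this patch, not taken from scaling.py: the helper name `rms_penalties` and the toy values are hypothetical, and it only illustrates how the two one-sided, log-space RMS penalties behave and why only the max-RMS term gets the large 100.0 weight.

```python
import torch


def rms_penalties(rms: torch.Tensor, min_rms: float, max_rms: float):
    # Penalty for channels whose RMS is below min_rms: log(min_rms / rms) is
    # positive exactly when rms < min_rms, and relu() zeroes it elsewhere.
    # A non-positive min_rms means "no lower bound", so the penalty is zero.
    if min_rms <= 0:
        r_loss_min = torch.zeros_like(rms)
    else:
        r_loss_min = (min_rms / rms).log().relu()
    # Penalty for channels whose RMS exceeds max_rms, symmetric in log space.
    r_loss_max = (rms / max_rms).log().relu()
    return r_loss_min, r_loss_max


# Toy per-channel RMS values: one channel too small, one in range, one too big.
rms = torch.tensor([0.05, 0.5, 4.0])
r_min, r_max = rms_penalties(rms, min_rms=0.1, max_rms=2.0)
print(r_min)  # positive only for the first channel
print(r_max)  # positive only for the last channel

# As in the patch, only the max-RMS violation is scaled up, so if both limits
# are violated the optimizer is pushed to fix the max-RMS violation first.
loss = (r_min + 100.0 * r_max).sum()
print(loss)
```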