Fix Balancer to treat the max-rms and min-rms losses separately; only the max-rms loss is scaled up

Daniel Povey 2023-01-01 00:38:07 +08:00
parent 907d28ca2a
commit 8db0636f1d


@@ -1031,11 +1031,15 @@ class BalancerFunction(torch.autograd.Function):
         m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
-        rms_clamped = rms.clamp(min=min_rms, max=max_rms)
-        r_loss = (rms_clamped / rms).log().abs()
+        if min_rms <= 0:
+            r_loss_min = 0.0
+        else:
+            r_loss_min = (min_rms / rms).log().relu().sum()
+        r_loss_max = (rms / max_rms).log().relu().sum()
-        # put a much larger scale on the RMS loss, so that if both are violated we fix
-        # the RMS loss first.
-        loss = (m_loss + 100.0 * r_loss).sum()
+        # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+        # m_loss are violated we fix the RMS loss first.
+        loss = (m_loss + r_loss_min + 100.0 * r_loss_max).sum()
         loss.backward()
         loss_grad = x.grad
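
The changed lines split the single symmetric RMS penalty, (rms_clamped / rms).log().abs(), into two one-sided log-ratio penalties: log(min_rms / rms) is positive exactly when rms < min_rms, and log(rms / max_rms) is positive exactly when rms > max_rms, so relu() keeps only the violated side. Below is a minimal standalone sketch of that computation, not the commit's actual BalancerFunction: the function name balancer_loss, the per-channel statistics over dim 1, and the usage values are assumptions; only the names min_mean, max_mean, min_rms, max_rms, m, rms, m_loss, r_loss_min, and r_loss_max come from the diff.

import torch

def balancer_loss(x: torch.Tensor,
                  min_mean: float, max_mean: float,
                  min_rms: float, max_rms: float) -> torch.Tensor:
    # Compute per-channel statistics, treating dim 1 as the channel
    # dimension (an assumption of this sketch).
    dims = [d for d in range(x.ndim) if d != 1]
    m = x.mean(dim=dims)
    rms = (x ** 2).mean(dim=dims).sqrt()

    # Penalize means that fall outside [min_mean, max_mean].
    m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()

    # Separate one-sided RMS penalties: log(min_rms / rms) > 0 only when
    # rms < min_rms, and log(rms / max_rms) > 0 only when rms > max_rms,
    # so relu() zeroes out whichever constraint is satisfied.
    if min_rms <= 0:
        r_loss_min = 0.0
    else:
        r_loss_min = (min_rms / rms).log().relu().sum()
    r_loss_max = (rms / max_rms).log().relu().sum()

    # Only the max-RMS violation gets the large 100.0 scale, so when both
    # the mean and the max-RMS constraints are violated, the max-RMS one
    # dominates the gradient and is fixed first.  (The commit adds the
    # scalar terms before a single .sum(); summing m_loss first here is a
    # simplification that keeps the scalar penalties unscaled.)
    return m_loss.sum() + r_loss_min + 100.0 * r_loss_max

A quick way to exercise it: an input whose RMS sits well above max_rms produces a loss dominated by the 100.0 * r_loss_max term.

x = (torch.randn(4, 8, 100) * 5.0).requires_grad_()  # per-channel RMS ~5.0
loss = balancer_loss(x, min_mean=-0.1, max_mean=0.1, min_rms=0.1, max_rms=2.0)
loss.backward()  # x.grad pushes activations back toward the allowed RMS range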