Fix bugs in how the max_rms/min_rms constraints were applied, which had the effect of making min_rms dominate over the mean constraint.

Daniel Povey 2023-01-01 13:05:41 +08:00
parent 8db0636f1d
commit 1797d0ec6d


@@ -1030,12 +1030,11 @@ class BalancerFunction(torch.autograd.Function):
             # part of loss that relates to mean / stddev
             m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
-            rms_clamped = rms.clamp(min=min_rms, max=max_rms)
-            if min_rms <= 0:
-                r_loss_min = 0.0
-            else:
-                r_loss_min = (min_rms / rms).log().relu().sum()
-            r_loss_max = (rms / max_rms).log().relu().sum()
+            r_loss_min = (min_rms / rms).log().relu()
+            r_loss_max = (rms / max_rms).log().relu()
             # put a much larger scale on the RMS-max-limit loss, so that if both it and the
             # m_loss are violated we fix the RMS loss first.
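
A minimal sketch (not part of the commit; the tensor shape, reduction dimension, and constraint values are illustrative assumptions) showing why the old code let the min_rms term dominate: summing the per-channel penalty with .sum() produces a single scalar that grows with the number of channels, while m_loss stays per-element.

import torch

torch.manual_seed(0)
x = torch.randn(4, 512) * 0.05   # hypothetical activations with RMS ~0.05
min_mean, max_mean = -0.1, 0.1   # illustrative constraint values
min_rms, max_rms = 0.1, 10.0

m = x.mean(dim=0, keepdim=True)
rms = (x ** 2).mean(dim=0, keepdim=True).clamp(min=1.0e-20).sqrt()

# part of loss that relates to mean / stddev, as in the diff above
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()

# old (buggy): .sum() collapses the per-channel penalty into one scalar,
# roughly channel-count times larger than any single m_loss entry
r_loss_min_old = (min_rms / rms).log().relu().sum()

# new (fixed): keep the penalty per-element, on the same scale as m_loss
r_loss_min = (min_rms / rms).log().relu()
r_loss_max = (rms / max_rms).log().relu()

print(f"m_loss mean:       {m_loss.mean().item():.4f}")     # ~0, means in range
print(f"old summed r_loss: {r_loss_min_old.item():.1f}")    # hundreds; dominates
print(f"new r_loss mean:   {r_loss_min.mean().item():.4f}")  # same scale as m_loss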