Revert some changes to Balancer.

This commit is contained in:
Daniel Povey 2023-01-01 23:02:52 +08:00
parent e52bfb7219
commit a2227a07fc

View File

@@ -1023,17 +1023,14 @@ class BalancerFunction(torch.autograd.Function):
# part of loss that relates to mean / stddev
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
if min_rms <= 0:
r_loss_min = 0.0
else:
r_loss_min = (min_rms / rms).log().relu()
r_loss_max = (rms / max_rms).log().relu()
# put a much larger scale on the RMS-max-limit loss, so that if both it and the
# m_loss are violated we fix the RMS loss first.
loss = (m_loss + r_loss_min + 100.0 * r_loss_max).sum()
rms_clamped = rms.clamp(min=min_rms, max=max_rms)
r_loss = (rms_clamped / rms).log().abs()
loss.backward()
loss = (m_loss + r_loss)
loss.backward(gradient=torch.ones_like(loss))
loss_grad = x.grad
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)