Make RMS loss dominate mean loss in Balancer if both are active; remove the 4x scale introduced in 814.

Daniel Povey 2023-01-01 00:09:14 +08:00
parent a2815ea0df
commit 907d28ca2a


@@ -1033,15 +1033,14 @@ class BalancerFunction(torch.autograd.Function):
 rms_clamped = rms.clamp(min=min_rms, max=max_rms)
 r_loss = (rms_clamped / rms).log().abs()
-loss = (m_loss + r_loss).sum()
+# put a much larger scale on the RMS loss, so that if both are violated we fix
+# the RMS loss first.
+loss = (m_loss + 100.0 * r_loss).sum()
 loss.backward()
 loss_grad = x.grad
 loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
-# use 4 times the normal grad_scale in dimensions where the max_rms constraint is violated;
-# we sometimes have trouble enforcing this one.
-grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms))
 loss_grad = loss_grad * (grad_scale / loss_grad_rms)
 x_grad_float = x_grad.to(torch.float32)
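
The sketch below illustrates the effect of the change; it is not the library code. The shapes, the mean/RMS limits, and the absolute-deviation form of the mean penalty are assumptions made for the example; only the RMS penalty and the 100x weighting mirror the diff above.

```python
import torch

def balancer_penalty_sketch(x: torch.Tensor,
                            min_mean: float = -0.1, max_mean: float = 0.1,
                            min_rms: float = 0.1, max_rms: float = 10.0) -> torch.Tensor:
    # Penalty on the per-channel mean (assumed absolute-deviation form).
    mean = x.mean(dim=0, keepdim=True)
    m_loss = (mean - mean.clamp(min=min_mean, max=max_mean)).abs()

    # Penalty on the per-channel RMS, as in the diff: |log| of the ratio to the clamped RMS.
    rms = (x ** 2).mean(dim=0, keepdim=True).sqrt().clamp(min=1.0e-20)
    rms_clamped = rms.clamp(min=min_rms, max=max_rms)
    r_loss = (rms_clamped / rms).log().abs()

    # Weight the RMS term 100x so that when both constraints are violated,
    # the gradient direction is dominated by fixing the RMS first.
    return (m_loss + 100.0 * r_loss).sum()

# Example: channels with a large mean offset *and* an oversized RMS.
# The combined penalty, and hence the gradient direction, is driven
# almost entirely by the RMS term.
x = (torch.randn(1000, 4) * 20.0 + 5.0).requires_grad_(True)
loss = balancer_penalty_sketch(x)
loss.backward()
print(loss.item(), x.grad.abs().mean().item())
```

Presumably this is also why the 4x boost to grad_scale could be dropped: since the gradient is re-normalized by its own RMS afterwards, only the relative weight of the two loss terms decides which constraint gets corrected first, and the 100x factor on r_loss already does that.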