mirror of https://github.com/k2-fsa/icefall.git
Make RMS loss dominate mean loss in Balancer if both are active; remove the 4x scale introduced in 814.
parent a2815ea0df
commit 907d28ca2a
@@ -1033,15 +1033,14 @@ class BalancerFunction(torch.autograd.Function):
             rms_clamped = rms.clamp(min=min_rms, max=max_rms)
             r_loss = (rms_clamped / rms).log().abs()
 
-            loss = (m_loss + r_loss).sum()
+            # put a much larger scale on the RMS loss, so that if both are violated we fix
+            # the RMS loss first.
+            loss = (m_loss + 100.0 * r_loss).sum()
 
             loss.backward()
             loss_grad = x.grad
             loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
 
-            # use 4 times the normal grad_scale in dimensions where the max_rms constraint is violated;
-            # we sometimes have trouble enforcing this one.
-            grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms))
             loss_grad = loss_grad * (grad_scale / loss_grad_rms)
 
             x_grad_float = x_grad.to(torch.float32)
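For illustration, below is a minimal, self-contained sketch of the effect of the 100x weight; it is not the actual BalancerFunction code. The constraint bounds and the simplified m_loss formula are assumptions made up for this example; only r_loss and the 100.0 * r_loss weighting come from the diff above. When both constraints are violated, the RMS term contributes the overwhelming share of the gradient, so the RMS constraint gets corrected first.

import torch

# Hypothetical illustration only: the bounds and the m_loss formula below
# are assumptions; r_loss and the 100x weight mirror the diff above.
min_mean, max_mean = -0.1, 0.1
min_rms, max_rms = 0.5, 2.0

# Activations that violate both the mean and the RMS constraints.
x = (torch.randn(16, 8) * 5.0 + 3.0).requires_grad_(True)

mean = x.mean(dim=0, keepdim=True)
rms = (x ** 2).mean(dim=0, keepdim=True).sqrt()

# Simplified mean penalty: distance of the mean from its allowed range.
m_loss = (mean - mean.clamp(min=min_mean, max=max_mean)).abs()

# RMS penalty exactly as in the diff.
rms_clamped = rms.clamp(min=min_rms, max=max_rms)
r_loss = (rms_clamped / rms).log().abs()

# After this commit the RMS term carries a 100x weight, so when both
# penalties are nonzero the gradient mostly pushes the RMS back into range.
loss = (m_loss + 100.0 * r_loss).sum()
loss.backward()
print(x.grad.abs().mean())  # gradient dominated by the RMS violation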