mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Revert some changes to Balancer.
This commit is contained in:
parent
e52bfb7219
commit
a2227a07fc
@ -1023,17 +1023,14 @@ class BalancerFunction(torch.autograd.Function):
|
|||||||
# part of loss that relates to mean / stddev
|
# part of loss that relates to mean / stddev
|
||||||
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
|
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
|
||||||
|
|
||||||
if min_rms <= 0:
|
|
||||||
r_loss_min = 0.0
|
|
||||||
else:
|
|
||||||
r_loss_min = (min_rms / rms).log().relu()
|
|
||||||
r_loss_max = (rms / max_rms).log().relu()
|
|
||||||
|
|
||||||
# put a much larger scale on the RMS-max-limit loss, so that if both it and the
|
# put a much larger scale on the RMS-max-limit loss, so that if both it and the
|
||||||
# m_loss are violated we fix the RMS loss first.
|
# m_loss are violated we fix the RMS loss first.
|
||||||
loss = (m_loss + r_loss_min + 100.0 * r_loss_max).sum()
|
rms_clamped = rms.clamp(min=min_rms, max=max_rms)
|
||||||
|
r_loss = (rms_clamped / rms).log().abs()
|
||||||
|
|
||||||
loss.backward()
|
loss = (m_loss + r_loss)
|
||||||
|
|
||||||
|
loss.backward(gradient=torch.ones_like(loss))
|
||||||
loss_grad = x.grad
|
loss_grad = x.grad
|
||||||
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
|
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user