mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Remove loss factor from decorr_loss_scale
This commit is contained in:
parent
8e56445c70
commit
b9a476c7bb
@ -745,7 +745,6 @@ class DecorrelateFunction(torch.autograd.Function):
|
||||
cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1-ctx.beta)
|
||||
inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
|
||||
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
|
||||
print("Decorrelate: norm_cov = ", norm_cov)
|
||||
|
||||
loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
|
||||
if random.random() < 0.01:
|
||||
@ -759,7 +758,7 @@ class DecorrelateFunction(torch.autograd.Function):
|
||||
# `loss ** 0.5` times the magnitude of the original grad.
|
||||
x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
|
||||
x_grad_old_scale = (x_grad ** 2).sum(dim=1)
|
||||
decorr_loss_scale = ctx.scale * loss.detach() ** 0.5
|
||||
decorr_loss_scale = ctx.scale
|
||||
scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
|
||||
x_grad_new = x_grad_new * scale.unsqueeze(-1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user