Remove loss factor from decorr_loss_scale

This commit is contained in:
Daniel Povey 2022-06-08 20:19:17 +08:00
parent 8e56445c70
commit b9a476c7bb

View File

@@ -745,7 +745,6 @@ class DecorrelateFunction(torch.autograd.Function):
cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1-ctx.beta)
inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
print("Decorrelate: norm_cov = ", norm_cov)
loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
if random.random() < 0.01:
@@ -759,7 +758,7 @@ class DecorrelateFunction(torch.autograd.Function):
# `loss ** 0.5` times the magnitude of the original grad.
x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
x_grad_old_scale = (x_grad ** 2).sum(dim=1)
decorr_loss_scale = ctx.scale * loss.detach() ** 0.5
decorr_loss_scale = ctx.scale
scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
x_grad_new = x_grad_new * scale.unsqueeze(-1)