diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index fa8f629f7..d6840afab 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -745,7 +745,6 @@ class DecorrelateFunction(torch.autograd.Function):
         cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1-ctx.beta)
         inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
         norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        print("Decorrelate: norm_cov = ", norm_cov)
         loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1

         if random.random() < 0.01:
@@ -759,7 +758,7 @@ class DecorrelateFunction(torch.autograd.Function):
         # `loss ** 0.5` times the magnitude of the original grad.
         x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
         x_grad_old_scale = (x_grad ** 2).sum(dim=1)
-        decorr_loss_scale = ctx.scale * loss.detach() ** 0.5
+        decorr_loss_scale = ctx.scale
         scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
         x_grad_new = x_grad_new * scale.unsqueeze(-1)
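For reference, below is a minimal standalone sketch (not part of the patch) of what the rescaling in the second hunk computes after this change: the decorrelation gradient x_grad_new is renormalized per frame so its magnitude becomes ctx.scale times the magnitude of the original gradient x_grad, with the previous loss.detach() ** 0.5 factor dropped so the step size no longer depends on the current decorrelation loss. The helper name rescale_decorr_grad and the toy shapes are assumptions for illustration only.

import torch


def rescale_decorr_grad(x_grad: torch.Tensor,
                        x_grad_new: torch.Tensor,
                        scale: float,
                        eps: float = 1.0e-10) -> torch.Tensor:
    # x_grad: (num_frames, num_channels), gradient of the main loss.
    # x_grad_new: same shape, raw gradient of the decorrelation loss.
    # Per-frame squared magnitudes of the two gradients.
    x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
    x_grad_old_scale = (x_grad ** 2).sum(dim=1)
    # After the patch the factor is just `scale`; previously it was
    # scale * loss.detach() ** 0.5, which tied the rescaling to the loss value.
    decorr_loss_scale = scale
    frame_scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + eps)) ** 0.5
    return x_grad_new * frame_scale.unsqueeze(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    x_grad = torch.randn(4, 8)
    x_grad_new = torch.randn(4, 8) * 100.0  # arbitrary magnitude before rescaling
    rescaled = rescale_decorr_grad(x_grad, x_grad_new, scale=0.1)
    # Each frame's rescaled norm is 0.1 times the corresponding norm of x_grad.
    print(rescaled.norm(dim=1) / x_grad.norm(dim=1))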