diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 7550c1b0e..40771f95a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -820,11 +820,18 @@ class DecorrelateFunction(torch.autograd.Function): x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1) x_sqnorm = (x.detach() ** 2).sum(dim=1) x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x in sum for cov + x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales + x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0) + # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half + # precision) + x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel()) + x_factor = (x_desired_sqscale / (x_sqnorm + ctx.eps)) ** 0.5 with torch.enable_grad(): scaled_x = x * x_factor.unsqueeze(-1) cov = _update_cov_stats(old_cov, scaled_x, ctx.beta) + assert old_cov.dtype != torch.float16 old_cov[:] = cov # update the stats outside! This is not really # how backprop is supposed to work, but this input # is not differentiable.. @@ -832,6 +839,7 @@ class DecorrelateFunction(torch.autograd.Function): if random.random() < 0.01: logging.info(f"Decorrelate: loss = {loss}") + loss.backward() decorr_x_grad = x.grad