From 6ed181595b79f15ad4a965f56a5bfe35f6d8c44f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Jun 2022 18:34:42 +0800 Subject: [PATCH] Scale by grad norm --- .../pruned_transducer_stateless2/scaling.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index a3a88d83a..c7c309e21 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -796,22 +796,16 @@ class DecorrelateFunction(torch.autograd.Function): # to have magnitudes proportional to the norm of the gradient on that # frame; the goal is to exclude "don't-care" frames such as padding frames from # the computation. - #x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1) + x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25 + x_grad_sqrt_norm /= x_grad_sqrt_norm.mean() + x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0) + x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0) with torch.enable_grad(): - #x_sqnorm = (x ** 2).sum(dim=1) + # scale up frames with larger grads. + x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1) - #x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov - #x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales - #x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0) - # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half - # precision) - #x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel()) - - #x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5 - - #scaled_x = x * x_factor.unsqueeze(-1) - cov = _update_cov_stats(old_cov, x, ctx.beta) + cov = _update_cov_stats(old_cov, x_scaled, ctx.beta) assert old_cov.dtype != torch.float16 old_cov[:] = cov # update the stats outside! This is not really # how backprop is supposed to work, but this input