Scale by grad norm

This commit is contained in:
Daniel Povey 2022-06-10 18:34:42 +08:00
parent 6a47bf1178
commit 6ed181595b

View File

@ -796,22 +796,16 @@ class DecorrelateFunction(torch.autograd.Function):
# to have magnitudes proportional to the norm of the gradient on that
# frame; the goal is to exclude "don't-care" frames such as padding frames from
# the computation.
#x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25
x_grad_sqrt_norm /= x_grad_sqrt_norm.mean()
x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0)
x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0)
with torch.enable_grad():
#x_sqnorm = (x ** 2).sum(dim=1)
# scale up frames with larger grads.
x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1)
#x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov
#x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales
#x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
# if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
# precision)
#x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
#x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
#scaled_x = x * x_factor.unsqueeze(-1)
cov = _update_cov_stats(old_cov, x, ctx.beta)
cov = _update_cov_stats(old_cov, x_scaled, ctx.beta)
assert old_cov.dtype != torch.float16
old_cov[:] = cov # update the stats outside! This is not really
# how backprop is supposed to work, but this input