Use 4 times the normal grad_scale for BasicNorm if max_rms violated.

Daniel Povey 2022-12-31 23:42:38 +08:00
parent 577c3ad390
commit 008dbaf745


@@ -1021,8 +1021,8 @@ class BalancerFunction(torch.autograd.Function):
x = x.detach()
x.requires_grad = True
mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
-uncentered_var = (x ** 2).mean(dim=mean_dims)
-mean = x.mean(dim=mean_dims)
+uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
+mean = x.mean(dim=mean_dims, keepdim=True)
stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
rms = uncentered_var.clamp(min=1.0e-20).sqrt()
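The keepdim=True change above keeps the reduced statistics broadcastable against x-shaped tensors (and against other keepdim statistics such as loss_grad_rms below). A minimal sketch, assuming an illustrative 3-D activation tensor and taking channel_dim as the last dimension (both assumptions, not from the commit), shows the shape difference:

import torch

x = torch.randn(4, 10, 8)            # illustrative (batch, time, channels) tensor
channel_dim = 2
mean_dims = [i for i in range(x.ndim) if i != channel_dim]

rms_flat = (x ** 2).mean(dim=mean_dims).sqrt()                 # shape (8,)
rms_keep = (x ** 2).mean(dim=mean_dims, keepdim=True).sqrt()   # shape (1, 1, 8)

# Only the keepdim version lines up element-wise with x for the later
# broadcasting arithmetic on grad_scale and loss_grad.
print(rms_flat.shape, rms_keep.shape)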
@@ -1039,6 +1039,9 @@ class BalancerFunction(torch.autograd.Function):
loss_grad = x.grad
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
+# use 4 times the normal grad_scale in dimensions where the max_rms constraint is violated;
+# we sometimes have trouble enforcing this one.
+grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms))
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
x_grad_float = x_grad.to(torch.float32)
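For reference, a minimal sketch of the per-channel boost added in this commit, using an assumed base grad_scale and max_rms (illustrative values, not from the library): the boolean mask (rms > max_rms) is 1 wherever the constraint is violated, so the effective scale there becomes 4 * grad_scale, and stays at grad_scale elsewhere.

import torch

grad_scale = 0.02                                # assumed base scale, for illustration
max_rms = 2.0                                    # assumed constraint, for illustration
rms = torch.tensor([[[0.5, 1.5, 3.0, 4.0]]])     # per-channel rms with keepdim-style shape

# Channels violating max_rms get 4x the scale; the rest are unchanged.
boosted = grad_scale * (1.0 + 3.0 * (rms > max_rms))
print(boosted)                                   # tensor([[[0.0200, 0.0200, 0.0800, 0.0800]]])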