diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 8a81e6c3f..9373e6c44 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1021,8 +1021,8 @@ class BalancerFunction(torch.autograd.Function):
                 x = x.detach()
                 x.requires_grad = True
                 mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
-                uncentered_var = (x ** 2).mean(dim=mean_dims)
-                mean = x.mean(dim=mean_dims)
+                uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
+                mean = x.mean(dim=mean_dims, keepdim=True)
                 stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
                 rms = uncentered_var.clamp(min=1.0e-20).sqrt()

@@ -1039,6 +1039,9 @@ class BalancerFunction(torch.autograd.Function):
                 loss_grad = x.grad
                 loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)

+                # use 4 times the normal grad_scale in dimensions where the max_rms constraint is violated;
+                # we sometimes have trouble enforcing this one.
+                grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms))
                 loss_grad = loss_grad * (grad_scale / loss_grad_rms)

                 x_grad_float = x_grad.to(torch.float32)
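
For reference, a minimal standalone sketch of what the two changes do (illustrative only, not the icefall code itself; the shapes and constants below are made up): keepdim=True keeps the reduced dimensions as size-1 axes so the per-channel statistics broadcast against x, and the boolean mask (rms > max_rms) promotes to 0/1 under arithmetic, giving grad_scale * 4 in violating channels and grad_scale * 1 elsewhere.

    import torch

    # Illustrative sketch only; shapes and constants are assumptions,
    # not taken from scaling.py.
    channel_dim = 1
    max_rms = 2.0
    grad_scale = 0.02

    x = torch.randn(4, 8, 100) * 3.0  # (batch, channel, time); RMS ~= 3 > max_rms
    mean_dims = [i for i in range(x.ndim) if i != channel_dim]

    # With keepdim=True the stats have shape (1, 8, 1) and broadcast against x;
    # without it they would have shape (8,), which only aligns with the last dim.
    uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
    mean = x.mean(dim=mean_dims, keepdim=True)
    rms = uncentered_var.clamp(min=1.0e-20).sqrt()

    # (rms > max_rms) is a bool tensor; 3.0 * bool promotes to 0.0/1.0, so the
    # factor is 4.0 where the constraint is violated and 1.0 elsewhere.
    grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms))
    print(grad_scale.shape)            # torch.Size([1, 8, 1])
    print(grad_scale[0, 0, 0].item())  # 0.08 for a violating channel (4 * 0.02)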