Mirror of https://github.com/k2-fsa/icefall.git
Use 4 times the normal grad_scale for BasicNorm if max_rms violated.
commit 008dbaf745 (parent 577c3ad390)
@@ -1021,8 +1021,8 @@ class BalancerFunction(torch.autograd.Function):
             x = x.detach()
             x.requires_grad = True
             mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
-            uncentered_var = (x ** 2).mean(dim=mean_dims)
-            mean = x.mean(dim=mean_dims)
+            uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
+            mean = x.mean(dim=mean_dims, keepdim=True)
             stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
             rms = uncentered_var.clamp(min=1.0e-20).sqrt()

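To see why keepdim=True matters here, a minimal standalone sketch (PyTorch; the batch/channels/time layout and shapes are illustrative, not taken from this file):

    import torch

    x = torch.randn(4, 8, 16)        # (batch, channels, time)
    channel_dim = 1
    mean_dims = [i for i in range(x.ndim) if i != channel_dim]  # [0, 2]

    mean_flat = x.mean(dim=mean_dims)                 # shape (8,)
    mean_kept = x.mean(dim=mean_dims, keepdim=True)   # shape (1, 8, 1)
    print(mean_flat.shape, mean_kept.shape)

    # (1, 8, 1) broadcasts against x along the channel axis as intended;
    # (8,) would line up with the trailing time axis instead (and here
    # would raise, since 8 != 16), so keepdim=True is needed before these
    # statistics are combined with per-element gradient terms below.
    centered = x - mean_kept                          # OK: shape (4, 8, 16)
    print(centered.shape)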
@@ -1039,6 +1039,9 @@ class BalancerFunction(torch.autograd.Function):
             loss_grad = x.grad
             loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)

+            # use 4 times the normal grad_scale in dimensions where the max_rms constraint is violated;
+            # we sometimes have trouble enforcing this one.
+            grad_scale = grad_scale * (1.0 + 3.0 * (rms > max_rms))
             loss_grad = loss_grad * (grad_scale / loss_grad_rms)

             x_grad_float = x_grad.to(torch.float32)
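The added line multiplies grad_scale by 4 exactly in the channels whose RMS exceeds max_rms, since (rms > max_rms) is a boolean tensor that becomes 1.0 or 0.0 under arithmetic. A minimal standalone sketch of that arithmetic (the helper name boosted_grad_scale and the grad_scale/max_rms values are illustrative, not from the diff):

    import torch

    def boosted_grad_scale(x: torch.Tensor,
                           channel_dim: int,
                           grad_scale: float = 0.02,   # illustrative value
                           max_rms: float = 10.0) -> torch.Tensor:
        # Per-channel RMS, kept broadcastable against x via keepdim=True.
        mean_dims = [i for i in range(x.ndim) if i != channel_dim]
        uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
        rms = uncentered_var.clamp(min=1.0e-20).sqrt()
        # 1.0 + 3.0 * bool gives 1.0 or 4.0: 4x the normal grad_scale
        # wherever the max_rms constraint is violated.
        return grad_scale * (1.0 + 3.0 * (rms > max_rms))

    # Channels scaled by 1..20, so roughly the upper half exceeds max_rms=10.
    x = torch.randn(4, 8, 16) * torch.linspace(1.0, 20.0, 8).view(1, 8, 1)
    scale = boosted_grad_scale(x, channel_dim=1)
    print(scale.squeeze())  # ~0.02 for channels within max_rms, 0.08 where violated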