Fix an issue with scaling of grad.

Daniel Povey 2022-10-15 15:36:55 +08:00
parent fcbb960da1
commit 593a6e946d


@@ -859,7 +859,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 (metric - ctx.whitening_limit).relu().backward()
                 penalty_grad = x_detached.grad
-                scale = ctx.grad_scale * (x.to(torch.float32).norm() /
+                scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                           (penalty_grad.norm() + 1.0e-20))
                 penalty_grad = penalty_grad * scale
         return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
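
The one-line change replaces x with x_grad in the scale computation: the penalty gradient is now normalized against the norm of the incoming gradient rather than the norm of the activations, so the penalty contributes grad_scale times the magnitude of the actual gradient signal regardless of how large the activations are. A minimal, self-contained sketch of the pattern (the identity forward and the mean-square metric here are hypothetical stand-ins for the real whitening metric; only the scale expression mirrors the committed code):

import torch

class ScaledPenaltyFunction(torch.autograd.Function):
    # Hypothetical, simplified stand-in for WhiteningPenaltyFunction:
    # identity in the forward pass, penalty injected in the backward pass.

    @staticmethod
    def forward(ctx, x, limit, grad_scale):
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_detached = x.to(torch.float32).detach()
            x_detached.requires_grad = True
            # Stand-in metric (the real code uses a whitening metric):
            # penalize mean squared activation above `limit`.
            metric = (x_detached ** 2).mean()
            (metric - ctx.limit).relu().backward()
            penalty_grad = x_detached.grad
            # The fix: scale by x_grad.norm(), not x.norm(), so the penalty
            # term's norm is grad_scale times the incoming gradient's norm.
            scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                      (penalty_grad.norm() + 1.0e-20))
            penalty_grad = penalty_grad * scale
        return x_grad + penalty_grad.to(x_grad.dtype), None, None

x = torch.randn(4, 8, requires_grad=True)
y = ScaledPenaltyFunction.apply(x, 0.1, 0.02)
y.sum().backward()
print(x.grad.norm())  # penalty contribution has norm ~0.02 * |x_grad|

Tying the scale to x_grad keeps the penalty proportional to the real training signal; under the old behaviour, scaling against x could make the penalty arbitrarily large relative to the gradient whenever activations were large but gradients were small.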