mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add clip_grad
This commit is contained in:
parent
3351402875
commit
bcc9971ebe
@ -287,44 +287,23 @@ def softmax(x: Tensor,
|
||||
return SoftmaxFunction.apply(x, dim)
|
||||
|
||||
|
||||
class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
class ClipGradFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
coeffs: Tensor,
|
||||
direction: Tensor,
|
||||
channel_dim: int,
|
||||
grad_scale: float) -> Tensor:
|
||||
ctx.channel_dim = channel_dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.save_for_backward(x.detach(),
|
||||
coeffs.detach(),
|
||||
direction.detach())
|
||||
limit: float):
|
||||
ctx.limit = limit
|
||||
return x
|
||||
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad, *args):
|
||||
with torch.enable_grad():
|
||||
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
||||
x_orig.requires_grad = True
|
||||
num_channels = x_orig.shape[ctx.channel_dim]
|
||||
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
||||
new_direction.requires_grad = False
|
||||
x = x - x.mean(dim=0)
|
||||
x_var = (x ** 2).mean()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual ** 2).mean()
|
||||
# `variance_proportion` is the proportion of the variance accounted for
|
||||
# by the top eigen-direction. This is to be minimized.
|
||||
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
||||
variance_proportion.backward()
|
||||
x_orig_grad = x_orig.grad
|
||||
x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
|
||||
return x_grad + x_extra_grad.detach(), None, None, None, None
|
||||
return x_grad.clamp(-ctx.limit, ctx.limit), None
|
||||
|
||||
|
||||
def clip_grad(x: Tensor, limit: float):
|
||||
return ClipGradFunction.apply(x, limit)
|
||||
|
||||
|
||||
|
||||
class BiasNormFunction(torch.autograd.Function):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user