Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Add clip_grad
This commit is contained in:
parent 3351402875
commit bcc9971ebe
@@ -287,44 +287,23 @@ def softmax(x: Tensor,
     return SoftmaxFunction.apply(x, dim)
 
 
-class MaxEigLimiterFunction(torch.autograd.Function):
+class ClipGradFunction(torch.autograd.Function):
     @staticmethod
     def forward(
             ctx,
             x: Tensor,
-            coeffs: Tensor,
-            direction: Tensor,
-            channel_dim: int,
-            grad_scale: float) -> Tensor:
-        ctx.channel_dim = channel_dim
-        ctx.grad_scale = grad_scale
-        ctx.save_for_backward(x.detach(),
-                              coeffs.detach(),
-                              direction.detach())
+            limit: float):
+        ctx.limit = limit
         return x
 
 
     @staticmethod
     def backward(ctx, x_grad, *args):
-        with torch.enable_grad():
-            (x_orig, coeffs, new_direction) = ctx.saved_tensors
-            x_orig.requires_grad = True
-            num_channels = x_orig.shape[ctx.channel_dim]
-            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
-            new_direction.requires_grad = False
-            x = x - x.mean(dim=0)
-            x_var = (x ** 2).mean()
-            x_residual = x - coeffs * new_direction
-            x_residual_var = (x_residual ** 2).mean()
-            # `variance_proportion` is the proportion of the variance accounted for
-            # by the top eigen-direction.  This is to be minimized.
-            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
-            variance_proportion.backward()
-        x_orig_grad = x_orig.grad
-        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
-        return x_grad + x_extra_grad.detach(), None, None, None, None
+        return x_grad.clamp(-ctx.limit, ctx.limit), None
+
+
+def clip_grad(x: Tensor, limit: float):
+    return ClipGradFunction.apply(x, limit)
 
 
 class BiasNormFunction(torch.autograd.Function):
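For context, a minimal, self-contained sketch (not part of the commit) of how the new clip_grad helper behaves: the forward pass returns x unchanged, while the backward pass clamps the incoming gradient element-wise to [-limit, limit]. The ClipGradFunction/clip_grad definitions are restated locally so the snippet runs on its own; the example tensor values are illustrative only.

import torch
from torch import Tensor


class ClipGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, limit: float):
        # Identity in the forward direction; only remember the clipping limit.
        ctx.limit = limit
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        # Clamp the gradient element-wise; None for the non-tensor `limit` argument.
        return x_grad.clamp(-ctx.limit, ctx.limit), None


def clip_grad(x: Tensor, limit: float):
    return ClipGradFunction.apply(x, limit)


# Illustrative check (made-up values): large upstream gradients are clipped to +/- 0.5.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = clip_grad(x, 0.5)  # y equals x in the forward pass
(y * torch.tensor([0.1, 1.0, 10.0])).sum().backward()
print(x.grad)  # tensor([0.1000, 0.5000, 0.5000])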