Add clip_grad

Daniel Povey 2023-05-23 14:00:56 +08:00
parent 3351402875
commit bcc9971ebe


@@ -287,44 +287,23 @@ def softmax(x: Tensor,
     return SoftmaxFunction.apply(x, dim)
 
 
-class MaxEigLimiterFunction(torch.autograd.Function):
+class ClipGradFunction(torch.autograd.Function):
     @staticmethod
     def forward(
             ctx,
             x: Tensor,
-            coeffs: Tensor,
-            direction: Tensor,
-            channel_dim: int,
-            grad_scale: float) -> Tensor:
-        ctx.channel_dim = channel_dim
-        ctx.grad_scale = grad_scale
-        ctx.save_for_backward(x.detach(),
-                              coeffs.detach(),
-                              direction.detach())
+            limit: float):
+        ctx.limit = limit
         return x
 
     @staticmethod
     def backward(ctx, x_grad, *args):
-        with torch.enable_grad():
-            (x_orig, coeffs, new_direction) = ctx.saved_tensors
-            x_orig.requires_grad = True
-            num_channels = x_orig.shape[ctx.channel_dim]
-            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
-            new_direction.requires_grad = False
-            x = x - x.mean(dim=0)
-            x_var = (x ** 2).mean()
-            x_residual = x - coeffs * new_direction
-            x_residual_var = (x_residual ** 2).mean()
-            # `variance_proportion` is the proportion of the variance accounted for
-            # by the top eigen-direction.  This is to be minimized.
-            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
-            variance_proportion.backward()
-        x_orig_grad = x_orig.grad
-        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
-        return x_grad + x_extra_grad.detach(), None, None, None, None
+        return x_grad.clamp(-ctx.limit, ctx.limit), None
+
+
+def clip_grad(x: Tensor, limit: float):
+    return ClipGradFunction.apply(x, limit)
 
 
 class BiasNormFunction(torch.autograd.Function):
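
For illustration only, a minimal usage sketch of the new clip_grad helper (the tensor shape, the scale factor 100.0, and the limit 5.0 are made up here, and it assumes the clip_grad / ClipGradFunction definitions from the diff are in scope): clip_grad(x, limit) is the identity in the forward pass, and in the backward pass it clamps each element of the incoming gradient to [-limit, limit] before it reaches x.

import torch

x = torch.randn(4, 8, requires_grad=True)   # hypothetical input tensor
y = clip_grad(x, 5.0)                       # identity in the forward pass
loss = (100.0 * y).sum()                    # d(loss)/dy is 100 everywhere
loss.backward()
print(x.grad.unique())                      # tensor([5.]): gradient clamped to [-5, 5]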