Add clip_grad

This commit is contained in:
Daniel Povey 2023-05-23 14:00:56 +08:00
parent 3351402875
commit bcc9971ebe

View File

@ -287,44 +287,23 @@ def softmax(x: Tensor,
return SoftmaxFunction.apply(x, dim)
class MaxEigLimiterFunction(torch.autograd.Function):
class ClipGradFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
coeffs: Tensor,
direction: Tensor,
channel_dim: int,
grad_scale: float) -> Tensor:
ctx.channel_dim = channel_dim
ctx.grad_scale = grad_scale
ctx.save_for_backward(x.detach(),
coeffs.detach(),
direction.detach())
limit: float):
ctx.limit = limit
return x
@staticmethod
def backward(ctx, x_grad, *args):
with torch.enable_grad():
(x_orig, coeffs, new_direction) = ctx.saved_tensors
x_orig.requires_grad = True
num_channels = x_orig.shape[ctx.channel_dim]
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
new_direction.requires_grad = False
x = x - x.mean(dim=0)
x_var = (x ** 2).mean()
x_residual = x - coeffs * new_direction
x_residual_var = (x_residual ** 2).mean()
# `variance_proportion` is the proportion of the variance accounted for
# by the top eigen-direction. This is to be minimized.
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
variance_proportion.backward()
x_orig_grad = x_orig.grad
x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
return x_grad + x_extra_grad.detach(), None, None, None, None
return x_grad.clamp(-ctx.limit, ctx.limit), None
def clip_grad(x: Tensor, limit: float):
return ClipGradFunction.apply(x, limit)
class BiasNormFunction(torch.autograd.Function):