diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 93257fdb6..828dd0fed 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -287,44 +287,23 @@ def softmax(x: Tensor,
     return SoftmaxFunction.apply(x, dim)
 
 
-class MaxEigLimiterFunction(torch.autograd.Function):
+class ClipGradFunction(torch.autograd.Function):
     @staticmethod
     def forward(
             ctx,
             x: Tensor,
-            coeffs: Tensor,
-            direction: Tensor,
-            channel_dim: int,
-            grad_scale: float) -> Tensor:
-        ctx.channel_dim = channel_dim
-        ctx.grad_scale = grad_scale
-        ctx.save_for_backward(x.detach(),
-                              coeffs.detach(),
-                              direction.detach())
+            limit: float):
+        ctx.limit = limit
         return x
 
-
     @staticmethod
     def backward(ctx, x_grad, *args):
-        with torch.enable_grad():
-            (x_orig, coeffs, new_direction) = ctx.saved_tensors
-            x_orig.requires_grad = True
-            num_channels = x_orig.shape[ctx.channel_dim]
-            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
-            new_direction.requires_grad = False
-            x = x - x.mean(dim=0)
-            x_var = (x ** 2).mean()
-            x_residual = x - coeffs * new_direction
-            x_residual_var = (x_residual ** 2).mean()
-            # `variance_proportion` is the proportion of the variance accounted for
-            # by the top eigen-direction. This is to be minimized.
-            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
-            variance_proportion.backward()
-        x_orig_grad = x_orig.grad
-        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
-        return x_grad + x_extra_grad.detach(), None, None, None, None
+        return x_grad.clamp(-ctx.limit, ctx.limit), None
 
 
+def clip_grad(x: Tensor, limit: float):
+    return ClipGradFunction.apply(x, limit)
+
 class BiasNormFunction(torch.autograd.Function):
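
Note (not part of the patch): the change replaces the eigenvalue-limiting backward hook with plain element-wise gradient clipping, so the forward pass stays an identity and only the incoming gradient is clamped to [-limit, limit]. Below is a minimal, standalone sketch of how the new clip_grad helper behaves; ClipGradFunction is re-declared here only so the snippet runs on its own, and the toy usage is illustrative rather than taken from the recipe.

import torch
from torch import Tensor


class ClipGradFunction(torch.autograd.Function):
    # Same definition as in the patch: identity in the forward pass,
    # element-wise clamp of the incoming gradient in the backward pass.
    @staticmethod
    def forward(ctx, x: Tensor, limit: float):
        ctx.limit = limit
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        # Gradients for (x, limit); the float limit gets no gradient.
        return x_grad.clamp(-ctx.limit, ctx.limit), None


def clip_grad(x: Tensor, limit: float):
    return ClipGradFunction.apply(x, limit)


if __name__ == "__main__":
    x = torch.randn(4, requires_grad=True)
    h = x * 1.0                    # a non-leaf activation, as in normal use
    y = clip_grad(h, limit=0.1)
    (y * 10.0).sum().backward()    # upstream gradient of 10.0 per element
    print(x.grad)                  # every entry is clamped to 0.1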