diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 40a493599..270cd2044 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -430,32 +430,44 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
-class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+class BasicNormFunction(torch.autograd.Function):
+    # This computes:
+    #   scales = (torch.mean((x + bias) ** 2, keepdim=True) + eps.exp()) ** -0.5
+    #   return x * scales
+    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+    # it only has to store the returned value (chances are, this will also be needed for
+    # some other reason, related to the next operation, so we can save memory).
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int) -> Tensor:
         assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
         ctx.channel_dim = channel_dim
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
-        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
-        ctx.save_for_backward(x, bias)
+        scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+        ans = x * scales
+        ctx.save_for_backward(ans, scales, bias, eps)
         return ans
 
     @staticmethod
     @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        x, bias = ctx.saved_tensors
+        ans, scales, bias, eps = ctx.saved_tensors
+        x = ans / scales
         x = x.detach()
         bias = bias.detach()
+        eps = eps.detach()
         x.requires_grad = True
         bias.requires_grad = True
+        eps.requires_grad = True
         with torch.enable_grad():
-            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            # recompute scales from x, bias and eps.
+            scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            ans = x * scales
             ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), None
+        return x.grad, bias.grad.flatten(), eps.grad, None
@@ -512,6 +524,17 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         eps = self.eps
+
+        if torch.jit.is_scripting():
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim = channel_dim + x.ndim
+            bias = self.bias
+            for _ in range(channel_dim + 1, x.ndim):
+                bias = bias.unsqueeze(-1)
+            scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            return x * scales
+
         if self.training and random.random() < 0.25:
             # with probability 0.25, in training mode, clamp eps between the min
             # and max; this will encourage it to learn parameters within the
@@ -522,11 +545,7 @@ class BasicNorm(torch.nn.Module):
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
 
-        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
-        scales = (
-            norms + eps.exp()
-        ) ** -0.5
-        return x * scales
+        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim)
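
For reference, here is a minimal standalone sketch of the memory-saving pattern that BasicNormFunction uses above: forward saves only its output and the scales, and backward reconstructs the input as ans / scales and replays the computation under torch.enable_grad(). The class name MemoryEfficientNorm and the hard-coded epsilon are illustrative only (not part of the patch); the custom_fwd/custom_bwd mixed-precision decorators and the bias/eps handling are omitted to keep it short.

import torch
from torch import Tensor


class MemoryEfficientNorm(torch.autograd.Function):
    # Illustrative sketch only: same store-the-output idea as BasicNormFunction,
    # but with a fixed epsilon and no bias.
    @staticmethod
    def forward(ctx, x: Tensor, channel_dim: int) -> Tensor:
        ctx.channel_dim = channel_dim
        scales = (torch.mean(x ** 2, dim=channel_dim, keepdim=True) + 1.0e-05) ** -0.5
        ans = x * scales
        # Save only the output and the scales; x itself does not need to be kept alive.
        ctx.save_for_backward(ans, scales)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, scales = ctx.saved_tensors
        # Reconstruct the input from the output, then redo the forward with autograd enabled.
        x = (ans / scales).detach()
        x.requires_grad = True
        with torch.enable_grad():
            scales = (torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + 1.0e-05) ** -0.5
            (x * scales).backward(gradient=ans_grad)
        return x.grad, None


if __name__ == "__main__":
    # The recomputed gradient should match the direct computation.
    x = torch.randn(3, 8, dtype=torch.double, requires_grad=True)
    (g,) = torch.autograd.grad(MemoryEfficientNorm.apply(x, -1).sum(), x)

    x2 = x.detach().clone().requires_grad_(True)
    scales = (torch.mean(x2 ** 2, dim=-1, keepdim=True) + 1.0e-05) ** -0.5
    (x2 * scales).sum().backward()
    assert torch.allclose(g, x2.grad)

The trade-off is one extra elementwise division plus a recomputation of the mean in backward, in exchange for not having to store x between forward and backward.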