diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 04a2822ee..40a493599 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -430,6 +430,35 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
+class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
+        ctx.channel_dim = channel_dim
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
+        ctx.save_for_backward(x, bias)
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad: Tensor) -> Tensor:
+        x, bias = ctx.saved_tensors
+        x = x.detach()
+        bias = bias.detach()
+        x.requires_grad = True
+        bias.requires_grad = True
+        with torch.enable_grad():
+            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), None
+
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -446,7 +475,7 @@ class BasicNorm(torch.nn.Module):
 
     Args:
       num_channels: the number of channels, e.g. 512.
-      channel_dim: the axis/dimension corresponding to the channel, interprted
-        as an offset from the input's ndim if negative.  shis is NOT the
+      channel_dim: the axis/dimension corresponding to the channel, interpreted
+        as an offset from the input's ndim if negative.  This is NOT the
         num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}.
 
@@ -476,6 +505,7 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
 
@@ -491,8 +521,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+
+        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+            norms + eps.exp()
         ) ** -0.5
         return x * scales
 
@@ -1071,12 +1103,10 @@ class Whiten(nn.Module):
         self.grad_scale = grad_scale
 
         if isinstance(prob, float):
-            assert 0 < prob <= 1
-            self.prob = prob
-        else:
-            (self.min_prob, self.max_prob) = prob
-            assert 0 < self.min_prob < self.max_prob <= 1
-            self.prob = self.max_prob
+            prob = (prob, prob)
+        (self.min_prob, self.max_prob) = prob
+        assert 0 < self.min_prob <= self.max_prob <= 1
+        self.prob = self.max_prob
         self.name = None  # will be set in training loop
 
     def forward(self,
@@ -1718,7 +1748,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
-        min_prob=1.0,
+        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
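
A quick way to sanity-check the new ComputeSquaredMeanWithOffset is to compare it against the plain eager expression it replaces in BasicNorm.forward(). The snippet below is a minimal sketch, not part of the patch; it assumes it is run from egs/librispeech/ASR/pruned_transducer_stateless7/ so that scaling.py is importable, and the tensor shapes are arbitrary.

# Sanity check (not part of the patch): ComputeSquaredMeanWithOffset should
# agree with the eager computation in both its output and its gradients.
# Assumes scaling.py is on the import path (run from the recipe directory).
import torch

from scaling import ComputeSquaredMeanWithOffset

torch.manual_seed(0)
# (batch, num_channels, time); channel_dim=1 exercises the unsqueeze loop in forward().
x = torch.randn(3, 6, 5, dtype=torch.double, requires_grad=True)
bias = torch.randn(6, dtype=torch.double, requires_grad=True)
channel_dim = 1

# Custom autograd path, as used by BasicNorm.forward() after this change.
ans = ComputeSquaredMeanWithOffset.apply(x, bias, channel_dim)

# Reference path: broadcast the bias over the trailing dim, exactly as forward() does.
ref = torch.mean((x + bias.unsqueeze(-1)) ** 2, dim=channel_dim, keepdim=True)
assert torch.allclose(ans, ref)

# The recompute-in-backward implementation should reproduce autograd's gradients.
g = torch.randn_like(ans)
x_grad, bias_grad = torch.autograd.grad(ans, (x, bias), grad_outputs=g)
x_grad_ref, bias_grad_ref = torch.autograd.grad(ref, (x, bias), grad_outputs=g)
assert torch.allclose(x_grad, x_grad_ref)
assert torch.allclose(bias_grad, bias_grad_ref)
print("ComputeSquaredMeanWithOffset matches the eager computation")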
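
The Whiten.__init__ change collapses the float/tuple branches into one path: a bare float prob is first normalized to a degenerate (prob, prob) pair, after which both cases share the same unpacking and assertion (note the relaxed "<=" between min_prob and max_prob). The helper below is hypothetical, not code from the patch; it just mirrors that normalization in isolation.

# Hypothetical stand-alone mirror of the new prob handling in Whiten.__init__.
from typing import Tuple, Union

def normalize_prob(prob: Union[float, Tuple[float, float]]) -> Tuple[float, float]:
    # A bare float becomes a degenerate (min, max) range, so only one
    # code path remains for the subsequent unpacking and assertion.
    if isinstance(prob, float):
        prob = (prob, prob)
    (min_prob, max_prob) = prob
    assert 0 < min_prob <= max_prob <= 1
    return (min_prob, max_prob)

assert normalize_prob(0.5) == (0.5, 0.5)
assert normalize_prob((0.25, 1.0)) == (0.25, 1.0)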