diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index a790ce345..c5b748480 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -430,51 +430,6 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
-class BasicNormFunction(torch.autograd.Function):
-    # This computes:
-    #   scales = torch.mean((x - bias) ** 2, keepdim=True) + eps.exp()
-    #   return (x - bias) * scales
-    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
-    # it can just store the returned value (chances are, this will also be needed for
-    # some other reason, related to the next operation, so we can save memory).
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int,
-                store_output_for_backprop: bool) -> Tensor:
-        assert bias.ndim == 1
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
-        ctx.store_output_for_backprop = store_output_for_backprop
-        ctx.channel_dim = channel_dim
-        for _ in range(channel_dim + 1, x.ndim):
-            bias = bias.unsqueeze(-1)
-        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-        ans = x * scales
-        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), bias.detach(), eps.detach())
-        return ans
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, bias, eps = ctx.saved_tensors
-        if ctx.store_output_for_backprop:
-            x = ans_or_x / scales
-        else:
-            x = ans_or_x
-        x = x.detach()
-        x.requires_grad = True
-        bias.requires_grad = True
-        eps.requires_grad = True
-        with torch.enable_grad():
-            # recompute scales from x, bias and eps.
-            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            ans = x * scales
-            ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), eps.grad, None, None
-
-
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -491,7 +446,7 @@ class BasicNorm(torch.nn.Module):
 
     Args:
        num_channels: the number of channels, e.g. 512.
-      channel_dim: the axis/dimension corresponding to the channel, 
-        interprted as an offset from the input's ndim if negative.
-        shis is NOT the num_channels; it should typically be one of
+      channel_dim: the axis/dimension corresponding to the channel,
+        interpreted as an offset from the input's ndim if negative.
+        This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
@@ -501,13 +456,10 @@ class BasicNorm(torch.nn.Module):
           to indicate the connection with conventional LayerNorm.
        learn_eps: if true, we learn epsilon; if false, we keep it
           at the initial value.
-       store_output_for_backprop: this option makes no difference
-          to the output, but may affect memory usage; determines
-          whether, for backprop purposes, we store the input or the output
-          of this module.
     eps_min: float
     eps_max: float
     """
+
     def __init__(
         self,
         num_channels: int,
@@ -516,7 +468,6 @@ class BasicNorm(torch.nn.Module):
         learn_eps: bool = True,
         eps_min: float = -3.0,
         eps_max: float = 3.0,
-        store_output_for_backprop: bool = True
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -525,24 +476,12 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
-        self.store_output_for_backprop = store_output_for_backprop
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         eps = self.eps
-
-        if torch.jit.is_scripting():
-            channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim = channel_dim + x.ndim
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            return x * scales
-
         if self.training and random.random() < 0.25:
             # with probability 0.25, in training mode, clamp eps between the min
             # and max; this will encourage it to learn parameters within the
@@ -552,9 +491,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
-
-        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim,
-                                       self.store_output_for_backprop)
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+        ) ** -0.5
+        return x * scales
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 2470c25c8..e4385f87a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -451,7 +451,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim, store_output_for_backprop=False)
+        self.norm_final = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
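Taken together, these hunks drop the learned `bias` and the memory-optimized `BasicNormFunction` (which stored only one tensor and recomputed the gradient in `backward`), so `BasicNorm` reduces to a plain RMS-style normalization handled by ordinary autograd: `x * ((x**2).mean(dim) + eps.exp()) ** -0.5`. For illustration only, here is a minimal standalone sketch of the simplified module; the class name `SimpleBasicNorm` is hypothetical, and the training-time clamping of `eps` to `[eps_min, eps_max]` shown in the diff is omitted for brevity.

```python
import torch
import torch.nn as nn
from torch import Tensor


class SimpleBasicNorm(nn.Module):
    """RMS-style norm: x * ((x**2).mean(dim) + exp(eps)) ** -0.5."""

    def __init__(
        self, num_channels: int, channel_dim: int = -1, eps: float = 0.25
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # As in the diff, eps is learned in log space so it stays positive
        # after exponentiation.
        self.eps = nn.Parameter(torch.tensor(eps).log().detach())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        # Mean of squares over the channel dim (no mean subtraction, no bias),
        # with exp(eps) as ballast, exactly mirroring the new forward().
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
            + self.eps.exp()
        ) ** -0.5
        return x * scales


if __name__ == "__main__":
    norm = SimpleBasicNorm(512)
    x = torch.randn(10, 4, 512)
    print(norm(x).shape)  # torch.Size([10, 4, 512])
```

The zipformer.py hunk follows directly: with `store_output_for_backprop` removed from the constructor, the call site becomes simply `BasicNorm(embed_dim)`.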