diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 270cd2044..f370f3386 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -439,23 +439,29 @@ class BasicNormFunction(torch.autograd.Function):
     # some other reason, related to the next operation, so we can save memory).
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int) -> Tensor:
+    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int,
+                store_output_for_backprop: bool) -> Tensor:
         assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
+        ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
         scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
         ans = x * scales
-        ctx.save_for_backward(ans, scales, bias, eps)
+        ctx.save_for_backward(ans if store_output_for_backprop else x,
+                              scales, bias, eps)
         return ans

     @staticmethod
     @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans, scales, bias, eps = ctx.saved_tensors
-        x = ans / scales
+        ans_or_x, scales, bias, eps = ctx.saved_tensors
+        if ctx.store_output_for_backprop:
+            x = ans_or_x / scales
+        else:
+            x = ans_or_x
         x = x.detach()
         bias = bias.detach()
         eps = eps.detach()
@@ -467,7 +473,7 @@ class BasicNormFunction(torch.autograd.Function):
         scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
         ans = x * scales
         ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), eps.grad, None
+        return x.grad, bias.grad.flatten(), eps.grad, None, None
@@ -497,10 +503,13 @@ class BasicNorm(torch.nn.Module):
           to indicate the connection with conventional LayerNorm.
        learn_eps: if true, we learn epsilon; if false, we keep it at the
           initial value.
+       store_output_for_backprop: this option makes no difference
+          to the output, but may affect memory usage; determines
+          whether, for backprop purposes, we store the input or the output
+          of this module.
        eps_min: float
        eps_max: float
     """
-
     def __init__(
         self,
         num_channels: int,
@@ -509,6 +518,7 @@ class BasicNorm(torch.nn.Module):
         learn_eps: bool = True,
         eps_min: float = -3.0,
         eps_max: float = 3.0,
+        store_output_for_backprop: bool = True
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -520,6 +530,7 @@ class BasicNorm(torch.nn.Module):
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
+        self.store_output_for_backprop = store_output_for_backprop

     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
@@ -544,7 +555,8 @@ class BasicNorm(torch.nn.Module):
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
-        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim)
+        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim,
+                                       self.store_output_for_backprop)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 3082edbff..ca15f0c8d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -451,7 +451,7 @@ class ZipformerEncoderLayer(nn.Module):
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

-        self.norm_final = BasicNorm(embed_dim)
+        self.norm_final = BasicNorm(embed_dim, store_output_for_backprop=False)

         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
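For reference, a minimal usage sketch of the new option (not part of the patch). It assumes the patched scaling.py is importable, e.g. when run from the pruned_transducer_stateless7 directory, and the tensor shapes are arbitrary. The idea behind the flag: backward needs one of (input, output) together with scales, and input can be reconstructed as output / scales, so the caller can choose to save whichever tensor a neighboring operation is already keeping alive for its own backward pass (cf. the "related to the next operation, so we can save memory" comment in the first hunk).

```python
# Illustrative sketch only -- not part of the patch.  Assumes the patched
# scaling.py is on the path; shapes and sizes are arbitrary.
import torch
from scaling import BasicNorm

x = torch.randn(2, 100, 512, requires_grad=True)  # (batch, time, channel)

# Default (store_output_for_backprop=True): forward saves its *output*.
# If the next operation also saves that same tensor for its backward pass,
# autograd keeps one shared tensor and the input can be freed sooner.
norm = BasicNorm(512)

# With False, forward saves its *input* instead, and backward skips the
# x = ans_or_x / scales reconstruction; this is what the patch selects for
# norm_final in ZipformerEncoderLayer.
norm_final = BasicNorm(512, store_output_for_backprop=False)

y = norm_final(norm(x))
y.sum().backward()      # exercises BasicNormFunction.backward twice
print(x.grad.shape)     # torch.Size([2, 100, 512])
```

Either setting saves exactly one tensor of the same size inside the function itself; any memory difference comes from whether that saved tensor can be shared with what an adjacent op has already stashed for backprop.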