diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 706c3c1c7..a199458f0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -450,8 +450,8 @@ class BasicNormFunction(torch.autograd.Function):
             bias = bias.unsqueeze(-1)
         scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
         ans = x * scales - bias
-        ctx.save_for_backward(ans if store_output_for_backprop else x,
-                              scales, bias, eps)
+        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x.detach(),
+                              scales.detach(), bias.detach(), eps.detach())
         return ans
 
     @staticmethod
@@ -463,8 +463,6 @@ class BasicNormFunction(torch.autograd.Function):
         else:
             x = ans_or_x
         x = x.detach()
-        bias = bias.detach()
-        eps = eps.detach()
         x.requires_grad = True
         bias.requires_grad = True
         eps.requires_grad = True
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index ca15f0c8d..dec6142b8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1813,8 +1813,8 @@ class Conv2dSubsampling(nn.Module):
 
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
                                        ConvNeXt(layer2_channels),
-                                       ConvNorm2d(layer2_channels,
-                                                  kernel_size=(15, 7)))  # (time, freq)
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
 
 
         self.conv2 = nn.Sequential(
@@ -1832,10 +1832,7 @@ class Conv2dSubsampling(nn.Module):
 
         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels),
-                                       ConvNorm2d(layer3_channels,
-                                                  kernel_size=(15, 5)))  # (time, freq)
-
+                                       ConvNeXt(layer3_channels))
         out_height = (((in_channels - 1) // 2) - 1) // 2