diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 742c314fa..f43fae528 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -382,8 +382,7 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales