From 2e0f4de8ffc545e4171a568bca8807cc908645ef Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 23 Dec 2022 15:59:51 +0800 Subject: [PATCH] Apply limit on BasicNorm.eps more effectively using limit_param_value; add final norm to Zipformer. --- .../ASR/pruned_transducer_stateless7/scaling.py | 12 +++--------- .../ASR/pruned_transducer_stateless7/zipformer.py | 2 ++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 679fe552c..88dac0ac2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -481,16 +481,10 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = eps.clamp(min=self.eps_min, max=self.eps_max) + eps = self.eps + if self.training: + eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max) eps = eps.exp() scales = ( (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) / diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 25676801d..1165ffaab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -216,6 +216,7 @@ class Zipformer(EncoderInterface): encoder_dim[-1], downsample=output_downsampling_factor, dropout=dropout) + self.norm = BasicNorm(num_channels=encoder_dim[-1]) def _init_skip_modules(self): @@ -357,6 +358,7 @@ class Zipformer(EncoderInterface): lengths = (lengths + 1) // 2 x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = self.norm(x) return x, lengths