diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 374470ff7..32c61e81d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -401,7 +401,9 @@ class ConformerEncoder(nn.Module): batch = self.count.item() if self.training: self.count += 1 - return min(1.0, batch / self.warmup_batches) + return min(1.0, batch / self.warmup_batches) + else: + return 1.0 # this is mostly a workaround for an issue with moderl averaging. def forward(