diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 8778dc5ba..83de82056 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -221,7 +221,10 @@ class ConformerEncoderLayer(nn.Module): warmup_scale = min(0.1 + warmup, 1.0) # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely # bypass it. - alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + if self.training: + alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + else: + alpha = 1.0 # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src))