diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4c3271ac6..b21b531d0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -610,7 +610,8 @@ class ConvNorm1d(torch.nn.Module): counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0) sqnorms = sqnorms * counts sqnorms = self.conv(sqnorms) - counts = self.conv(counts) + # the clamping is to avoid division by zero for padding frames. + counts = self.conv(counts).clamp(min=0.01) # scales: (N, 1, T) scales = (sqnorms / counts + eps.exp()) ** -0.5 return x * scales