diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0a25ca8f4..2c215f0f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -292,7 +292,14 @@ class Zipformer2(EncoderInterface):
         assert self.encoder_dim[0] == _encoder_dims0
-        max_downsampling_factor = max(self.downsampling_factor)
+
+        # this is a heuristic: the regions over which we mask frames are 4 times the
+        # largest of the downsampling_factors.  Can be any integer >= 1.
+        downsampling_multiple = 4
+
+        downsampling_factor = [ downsampling_multiple * i for i in self.downsampling_factor ]
+
+        max_downsampling_factor = max(downsampling_factor)
         num_frames_max = (num_frames0 + max_downsampling_factor - 1)
@@ -318,7 +325,7 @@ class Zipformer2(EncoderInterface):
         feature_masks = []
         for i in range(num_encoders):
-            ds = self.downsampling_factor[i]
+            ds = downsampling_factor[i]
             upsample_factor = (max_downsampling_factor // ds)
             frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
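
For context, here is a minimal standalone sketch of the masking heuristic the new comment describes: frames are dropped in contiguous blocks whose length is `downsampling_multiple` times the largest downsampling factor. This is not the patched Zipformer code itself; the `downsampling_factor` values, `mask_prob`, `num_frames`, and `batch_size` below are made-up illustration values.

```python
import torch

# Made-up illustration values; the real ones come from the model config.
downsampling_factor = [1, 2, 4, 8, 4, 2]
downsampling_multiple = 4

# Length of each masked region, in frames: downsampling_multiple times the
# largest downsampling factor (32 here, vs. 8 without the multiple).
block_len = downsampling_multiple * max(downsampling_factor)

num_frames = 500
batch_size = 2
mask_prob = 0.15

# One Bernoulli draw per block: a 0 masks an entire region of block_len frames.
num_blocks = (num_frames + block_len - 1) // block_len
block_mask = (torch.rand(num_blocks, batch_size, 1) > mask_prob).float()

# Repeat each block's value block_len times along the time axis and trim to
# num_frames, giving a frame-level mask whose zeroed regions span block_len frames.
frame_mask = (
    block_mask.unsqueeze(1)
    .expand(num_blocks, block_len, batch_size, 1)
    .reshape(num_blocks * block_len, batch_size, 1)[:num_frames]
)

print(frame_mask.shape)       # torch.Size([500, 2, 1])
print(frame_mask[:64, 0, 0])  # values change only at multiples of block_len
```

The unsqueeze/expand/reshape pattern mirrors the one visible in the second hunk, where each encoder stack re-expands a shared coarse mask by `max_downsampling_factor // ds` so that all stacks agree on which time regions are masked; multiplying every factor by `downsampling_multiple` simply makes those agreed-upon regions longer.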