diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a14fd9137..bfd88f825 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -296,27 +296,38 @@ class Zipformer2(EncoderInterface):
         num_frames_max = (num_frames0 + max_downsampling_factor - 1)
 
-        feature_mask_dropout_prob = 0.15
+        feature_mask_dropout_prob = 0.125
 
-        # frame_mask_max shape: (num_frames_max, batch_size, 1)
-        frame_mask_max = (torch.rand(num_frames_max, batch_size, 1,
+        # frame_mask_max1 shape: (num_frames_max, batch_size, 1)
+        frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
                                      device=x.device) >
                           feature_mask_dropout_prob).to(x.dtype)
 
+        # frame_mask_max2 has additional frames masked, about twice the number.
+        frame_mask_max2 = torch.logical_and(frame_mask_max1,
+                                            (torch.rand(num_frames_max, batch_size, 1,
+                                                        device=x.device) >
+                                             feature_mask_dropout_prob).to(x.dtype))
+
+        # dim: (num_frames_max, batch_size, 2)
+        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)
+
         feature_masks = []
         for i in range(num_encoders):
             ds = self.downsampling_factor[i]
             upsample_factor = (max_downsampling_factor // ds)
 
             frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
-                                                             batch_size, 1)
-                          .reshape(num_frames_max * upsample_factor, batch_size, 1))
+                                                             batch_size, 2)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 2))
             num_frames = (num_frames0 + ds - 1) // ds
             frame_mask = frame_mask[:num_frames]
             feature_mask = torch.ones(num_frames, batch_size,
                                       self.encoder_dim[i],
                                       dtype=x.dtype, device=x.device)
-            u = self.encoder_unmasked_dim[i]
-            feature_mask[:, :, u:] *= frame_mask
+            u1 = self.encoder_unmasked_dim[i]
+            u2 = (u1 + self.encoder_dim[i]) // 2
+            feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
+            feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
             feature_masks.append(feature_mask)
 
         return feature_masks
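
For readers skimming the patch: the hunk replaces the single per-frame keep-mask with two
correlated keep-masks (1.0 = keep, 0.0 = mask), where mask2 is a subset of mask1's kept
frames (logical_and of two keep-masks), so roughly twice as many frames lose their top
channel band as lose the middle band. Channels below encoder_unmasked_dim are never masked.
Below is a minimal standalone sketch of that scheme, not part of the patch; all concrete
numbers (num_frames, batch_size, encoder_dim, unmasked_dim) are made-up illustrative
values, not values from the recipe:

    import torch

    num_frames, batch_size = 100, 4          # illustrative shapes
    encoder_dim, unmasked_dim = 384, 256     # illustrative dims
    feature_mask_dropout_prob = 0.125

    # mask1: 1.0 = keep the frame's channels, 0.0 = mask them (~12.5% masked).
    mask1 = (torch.rand(num_frames, batch_size, 1)
             > feature_mask_dropout_prob).to(torch.float32)

    # mask2 keeps a frame only if mask1 kept it AND a second independent draw
    # keeps it, so about twice as many frames end up masked.
    mask2 = torch.logical_and(
        mask1.bool(),
        torch.rand(num_frames, batch_size, 1) > feature_mask_dropout_prob,
    ).to(torch.float32)

    # Channels [0, u1) are never masked; [u1, u2) follow mask1;
    # [u2, encoder_dim) follow the more aggressive mask2.
    u1 = unmasked_dim
    u2 = (u1 + encoder_dim) // 2

    feature_mask = torch.ones(num_frames, batch_size, encoder_dim)
    feature_mask[:, :, u1:u2] *= mask1
    feature_mask[:, :, u2:] *= mask2

    # Sanity check: every frame masked by mask1 is also masked by mask2.
    assert (mask2 <= mask1).all()
    print("frames masked by mask1:", int((mask1 == 0).sum()),
          "| by mask2:", int((mask2 == 0).sum()))

The apparent design point is a graded version of the existing feature-masking trick: the
highest channel band is dropped roughly twice as often as the middle band, while the
bottom unmasked_dim channels always survive, so the model cannot rely exclusively on the
extra capacity of the wider encoder dims.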