diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f774e5548..3102bf84d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -297,30 +297,25 @@ class Zipformer2(EncoderInterface):
         num_frames_max = (num_frames0 + max_downsampling_factor - 1)

         # we divide the dropped-out feature dimensions into two equal groups;
-        # the first group is dropped out with this probability, the second
-        # group is dropped out with about twice this probability.
-        feature_mask_dropout_prob = 0.1
+        # the first group is dropped out with probability 0.05, the second
+        # with probability approximately (0.2 + 0.05).
+        feature_mask_dropout_prob1 = 0.05
+        feature_mask_dropout_prob2 = 0.2

         # frame_mask_max1 shape: (num_frames_max, batch_size, 1)
         frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
                                       device=x.device) >
-                           feature_mask_dropout_prob).to(x.dtype)
+                           feature_mask_dropout_prob1).to(x.dtype)
         # frame_mask_max2 has additional frames masked, about twice the number.
         frame_mask_max2 = torch.logical_and(frame_mask_max1,
                              (torch.rand(num_frames_max, batch_size, 1,
                                          device=x.device) >
-                              feature_mask_dropout_prob).to(x.dtype))
-        # frame_mask_max3 has additional frames masked, about 3 times the number.
-        frame_mask_max3 = torch.logical_and(frame_mask_max2,
-                             (torch.rand(num_frames_max, batch_size, 1,
-                                         device=x.device) >
-                              feature_mask_dropout_prob).to(x.dtype))
-        # dim: (num_frames_max, batch_size, 3)
-        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)
+                              feature_mask_dropout_prob2).to(x.dtype))
+        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)

         feature_masks = []
         for i in range(num_encoders):
@@ -328,19 +323,17 @@
             upsample_factor = (max_downsampling_factor // ds)

             frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
-                                                             batch_size, 3)
-                          .reshape(num_frames_max * upsample_factor, batch_size, 3))
+                                                             batch_size, 2)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 2))
             num_frames = (num_frames0 + ds - 1) // ds
             frame_mask = frame_mask[:num_frames]
             channels = self.encoder_dim[i]
             feature_mask = torch.ones(num_frames, batch_size, channels,
                                       dtype=x.dtype, device=x.device)
             u1 = self.encoder_unmasked_dim[i]
-            u2 = u1 + (channels - u1) // 3
-            u3 = u1 + 2 * (channels - u1) // 3
+            u2 = u1 + (channels - u1) // 2
             feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
-            feature_mask[:, :, u2:u3] *= frame_mask[..., 1:2]
-            feature_mask[:, :, u3:channels] *= frame_mask[..., 2:3]
+            feature_mask[:, :, u2:] *= frame_mask[..., 1:2]

             feature_masks.append(feature_mask)

         return feature_masks
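
For reference, below is a minimal standalone sketch of the two-group feature masking this diff introduces. The sizes (num_frames_max, batch_size, channels) and unmasked_dim are made-up placeholders standing in for the input shape, self.encoder_dim[i] and self.encoder_unmasked_dim[i], and an extra .to() cast is added after torch.logical_and so the snippet runs on its own; it is not the exact code in zipformer.py.

    import torch

    # Hypothetical placeholder sizes; in Zipformer2 these come from the input
    # shape, self.encoder_dim[i] and self.encoder_unmasked_dim[i].
    num_frames_max, batch_size, channels = 100, 4, 384
    unmasked_dim = 256
    dtype, device = torch.float32, torch.device("cpu")

    feature_mask_dropout_prob1 = 0.05
    feature_mask_dropout_prob2 = 0.2

    # Group-1 frame mask: each frame is kept with probability 1 - 0.05.
    frame_mask1 = (torch.rand(num_frames_max, batch_size, 1, device=device)
                   > feature_mask_dropout_prob1).to(dtype)
    # Group-2 frame mask: a frame survives only if it passes both draws, so it
    # is dropped with probability roughly 0.05 + 0.2.
    frame_mask2 = torch.logical_and(
        frame_mask1,
        (torch.rand(num_frames_max, batch_size, 1, device=device)
         > feature_mask_dropout_prob2).to(dtype)).to(dtype)
    # Shape (num_frames_max, batch_size, 2): one mask column per channel group.
    frame_mask = torch.cat((frame_mask1, frame_mask2), dim=-1)

    # Split the maskable channels [u1, channels) into two equal halves and
    # zero each half on the frames its mask column drops; channels below u1
    # are never masked.
    feature_mask = torch.ones(num_frames_max, batch_size, channels, dtype=dtype)
    u1 = unmasked_dim
    u2 = u1 + (channels - u1) // 2
    feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
    feature_mask[:, :, u2:] *= frame_mask[..., 1:2]

As in the old three-group version, the channels below encoder_unmasked_dim are left untouched; the change only reduces the number of independently dropped channel groups from three to two and gives each group its own dropout probability.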