From f1dbf4222efc155cc10f84f8827c22693dfa0b57 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 29 Mar 2023 16:08:03 +0800
Subject: [PATCH] Divide feature_mask into 3 groups

---
 .../pruned_transducer_stateless7/zipformer.py | 26 +++++++++++++------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index fd8226cbc..37f4ecc3d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -299,7 +299,7 @@ class Zipformer2(EncoderInterface):
         # we divide the dropped-out feature dimensions into two equal groups;
         # the first group is dropped out with this probability, the second
         # group is dropped out with about twice this probability.
-        feature_mask_dropout_prob = 0.125
+        feature_mask_dropout_prob = 0.1
 
         # frame_mask_max1 shape: (num_frames_max, batch_size, 1)
         frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
@@ -312,8 +312,15 @@ class Zipformer2(EncoderInterface):
                                                        device=x.device) >
                                             feature_mask_dropout_prob).to(x.dtype))
 
-        # dim: (num_frames_max, batch_size, 2)
-        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)
+
+        # frame_mask_max3 has additional frames masked, about 3 times the number.
+        frame_mask_max3 = torch.logical_or(frame_mask_max2,
+                                           (torch.rand(num_frames_max, batch_size, 1,
+                                                       device=x.device) >
+                                            feature_mask_dropout_prob).to(x.dtype))
+
+        # dim: (num_frames_max, batch_size, 3)
+        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)
 
         feature_masks = []
         for i in range(num_encoders):
@@ -321,16 +328,19 @@ class Zipformer2(EncoderInterface):
             upsample_factor = (max_downsampling_factor // ds)
 
             frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
-                                                             batch_size, 2)
-                          .reshape(num_frames_max * upsample_factor, batch_size, 2))
+                                                             batch_size, 3)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 3))
             num_frames = (num_frames0 + ds - 1) // ds
             frame_mask = frame_mask[:num_frames]
-            feature_mask = torch.ones(num_frames, batch_size, self.encoder_dim[i],
+            channels = self.encoder_dim[i]
+            feature_mask = torch.ones(num_frames, batch_size, channels,
                                       dtype=x.dtype, device=x.device)
             u1 = self.encoder_unmasked_dim[i]
-            u2 = (u1 + self.encoder_dim[i]) // 2
+            u2 = u1 + (channels - u1) // 3
+            u3 = u1 + 2 * (channels - u1) // 3
             feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
-            feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
+            feature_mask[:, :, u2:u3] *= frame_mask[..., 1:2]
+            feature_mask[:, :, u3:channels] *= frame_mask[..., 2:3]
             feature_masks.append(feature_mask)
 
         return feature_masks
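
The patch splits the maskable feature dimensions (everything above encoder_unmasked_dim) into three equal groups and, per the code comments, drops whole frames in each successive group with a higher probability. Below is a minimal self-contained sketch of that three-group scheme, not code from the patch: the function name, shapes and default values are invented for illustration, and it chains the per-group keep masks by multiplication (so group k is dropped with probability 1 - (1 - p)^k, roughly k*p for small p), whereas the patch builds the extra groups with torch.logical_or on the keep masks. The group boundaries u1/u2/u3 follow the patch.

import torch

def make_three_group_feature_mask(num_frames: int,
                                  batch_size: int,
                                  encoder_dim: int,
                                  unmasked_dim: int,
                                  dropout_prob: float = 0.1,
                                  dtype: torch.dtype = torch.float32,
                                  device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # keep1: 1.0 where the frame is kept, 0.0 where it is dropped (probability dropout_prob).
    keep1 = (torch.rand(num_frames, batch_size, 1, device=device) > dropout_prob).to(dtype)
    # Chaining independent keep masks by multiplication drops a frame in group 2 with
    # probability 1 - (1 - p)^2 (about 2p) and in group 3 with about 3p.
    keep2 = keep1 * (torch.rand(num_frames, batch_size, 1, device=device) > dropout_prob).to(dtype)
    keep3 = keep2 * (torch.rand(num_frames, batch_size, 1, device=device) > dropout_prob).to(dtype)
    frame_mask = torch.cat((keep1, keep2, keep3), dim=-1)  # (num_frames, batch_size, 3)

    # Start from all-ones and zero out the three groups of masked dimensions frame by frame.
    feature_mask = torch.ones(num_frames, batch_size, encoder_dim, dtype=dtype, device=device)
    u1 = unmasked_dim                    # dims below u1 are never masked
    u2 = u1 + (encoder_dim - u1) // 3    # group boundaries, as in the patch
    u3 = u1 + 2 * (encoder_dim - u1) // 3
    feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
    feature_mask[:, :, u2:u3] *= frame_mask[..., 1:2]
    feature_mask[:, :, u3:] *= frame_mask[..., 2:3]
    return feature_mask

# Example: mask for 200 frames, batch of 4, 384 encoder dims of which 256 are never masked.
mask = make_three_group_feature_mask(200, 4, encoder_dim=384, unmasked_dim=256)
print(mask.shape)  # torch.Size([200, 4, 384])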