diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 4ce8f5e30..f270a48fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -276,7 +276,7 @@ class Zipformer2(EncoderInterface):
         Args:
            x: the embeddings (needed for the shape and dtype and device), of shape
-             (num_frames, batch_size, encoder_dims0)
+             (1, batch_size, encoder_dims0)
         """
         num_encoders = len(self.encoder_dim)
         if not self.training:
@@ -286,38 +286,34 @@ class Zipformer2(EncoderInterface):
         assert self.encoder_dim[0] == _encoder_dims0
 
+        feature_mask_dropout_prob = 0.125
-        # this is a heuristic, the regions over which we masked frames are 4 times the
-        # largest of the downsampling_factors.  Can be any integer >= 1.
-        downsampling_multiple = 4
+        # mask1 shape: (1, batch_size, 1)
+        mask1 = (torch.rand(1, batch_size, 1,
+                            device=x.device) >
+                 feature_mask_dropout_prob).to(x.dtype)
-        group_size = max(self.downsampling_factor) * downsampling_multiple
+        # mask2 has additional sequences masked, about twice the number.
+        mask2 = torch.logical_and(mask1,
+                                  (torch.rand(1, batch_size, 1,
+                                              device=x.device) >
+                                   feature_mask_dropout_prob).to(x.dtype))
-        num_groups = (num_frames0 + group_size - 1) // group_size
-
-        feature_mask_dropout_prob = 0.2
-
-        # shape: (num_groups, batch_size, 1)
-        group_mask = (torch.rand(num_groups, batch_size, 1,
-                                 device=x.device) >
-                      feature_mask_dropout_prob).to(x.dtype)
+        # dim: (1, batch_size, 2)
+        mask = torch.cat((mask1, mask2), dim=-1)
 
         feature_masks = []
         for i in range(num_encoders):
-            ds = self.downsampling_factor[i]
-            frames_per_group = (group_size // ds)
-
-            frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
-                                                         batch_size, 1)
-                          .reshape(num_groups * frames_per_group, batch_size, 1))
-            num_frames = (num_frames0 + ds - 1) // ds
-            frame_mask = frame_mask[:num_frames]
             channels = self.encoder_dim[i]
-            feature_mask = torch.ones(num_frames, batch_size, channels,
-                                      dtype=x.dtype, device=x.device)
+            feature_mask = torch.ones(1, batch_size, channels,
+                                      dtype=x.dtype, device=x.device)
             u1 = self.encoder_unmasked_dim[i]
-            feature_mask[:, :, u1:] *= frame_mask
+            u2 = u1 + (channels - u1) // 2
+
+            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+            feature_mask[:, :, u2:] *= mask[..., 1:2]
+
             feature_masks.append(feature_mask)
 
         return feature_masks
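
For reference, the patch replaces the old time-group feature masking (random masks drawn per group of frames, dropout probability 0.2) with per-sequence masking: two Bernoulli masks are drawn once per sequence, mask2 is nonzero only where mask1 is, and they zero out two upper channel ranges above encoder_unmasked_dim. The standalone sketch below (not part of the patch; the function name make_feature_masks and its arguments are illustrative, chosen here only to mirror the patched logic outside the Zipformer2 class) shows the same masking scheme:

from typing import List

import torch
from torch import Tensor


def make_feature_masks(
    batch_size: int,
    encoder_dims: List[int],
    unmasked_dims: List[int],
    dropout_prob: float = 0.125,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
) -> List[Tensor]:
    # mask1 is 1.0 with probability (1 - dropout_prob), drawn once per sequence.
    mask1 = (torch.rand(1, batch_size, 1, device=device) > dropout_prob).to(dtype)
    # mask2 is 1.0 only where mask1 is also 1.0, so the topmost channel range
    # ends up masked on roughly twice as many sequences as the middle range.
    mask2 = torch.logical_and(
        mask1.bool(),
        torch.rand(1, batch_size, 1, device=device) > dropout_prob,
    ).to(dtype)
    # Shape (1, batch_size, 2): one column per masked channel range.
    mask = torch.cat((mask1, mask2), dim=-1)

    feature_masks = []
    for channels, u1 in zip(encoder_dims, unmasked_dims):
        feature_mask = torch.ones(1, batch_size, channels, dtype=dtype, device=device)
        # Split the maskable channels [u1, channels) into two halves.
        u2 = u1 + (channels - u1) // 2
        # Channels [u1, u2) are zeroed when mask1 is 0; channels [u2, channels)
        # are zeroed when either mask is 0. Channels [0, u1) are never masked.
        feature_mask[:, :, u1:u2] *= mask[..., 0:1]
        feature_mask[:, :, u2:] *= mask[..., 1:2]
        feature_masks.append(feature_mask)
    return feature_masks


if __name__ == "__main__":
    # Example: two encoder stacks with dims 384 and 512, unmasked dim 256 each.
    masks = make_feature_masks(
        batch_size=4, encoder_dims=[384, 512], unmasked_dims=[256, 256]
    )
    for m in masks:
        print(m.shape)  # torch.Size([1, 4, 384]) and torch.Size([1, 4, 512])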