diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 538646948..f774e5548 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -307,17 +307,17 @@ class Zipformer2(EncoderInterface): feature_mask_dropout_prob).to(x.dtype) # frame_mask_max2 has additional frames masked, about twice the number. - frame_mask_max2 = torch.logical_or(frame_mask_max1, - (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) + frame_mask_max2 = torch.logical_and(frame_mask_max1, + (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype)) # frame_mask_max3 has additional frames masked, about 3 times the number. - frame_mask_max3 = torch.logical_or(frame_mask_max2, - (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) + frame_mask_max3 = torch.logical_and(frame_mask_max2, + (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype)) # dim: (num_frames_max, batch_size, 3) frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)