diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 944e010e8..2cbb9c570 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -163,15 +163,10 @@ class Conformer(EncoderInterface): # self.encoder_unmasked_dim. feature_mask_dropout_prob = 0.15 - # we only apply the random frame masking on 90% of sequences; we leave the remaining 10% - # un-masked so that the model has seen un-masked data. - sequence_mask_dropout_prob = 0.9 - - # frame_mask is 0 with probability `feature_mask_dropout_prob` # frame_mask1 shape: (num_frames1, batch_size, 1) - frame_mask1 = torch.logical_or( - torch.rand(num_frames1, batch_size, 1, device=x.device) > feature_mask_dropout_prob, - torch.rand(1, batch_size, 1, device=x.device) > sequence_mask_dropout_prob).to(x.dtype) + frame_mask1 = (torch.rand(num_frames1, batch_size, 1, device=x.device) > + feature_mask_dropout_prob).to(x.dtype) + feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1], dtype=x.dtype, device=x.device)