From 6e058b9ebdc1f3a21b13c11091b7a1d226615e5d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 29 Mar 2023 23:59:40 +0800 Subject: [PATCH] Fix or vs. and bug --- .../pruned_transducer_stateless7/zipformer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 538646948..f774e5548 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -307,17 +307,17 @@ class Zipformer2(EncoderInterface): feature_mask_dropout_prob).to(x.dtype) # frame_mask_max2 has additional frames masked, about twice the number. - frame_mask_max2 = torch.logical_or(frame_mask_max1, - (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) + frame_mask_max2 = torch.logical_and(frame_mask_max1, + (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype)) # frame_mask_max3 has additional frames masked, about 3 times the number. - frame_mask_max3 = torch.logical_or(frame_mask_max2, - (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) + frame_mask_max3 = torch.logical_and(frame_mask_max2, + (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype)) # dim: (num_frames_max, batch_size, 3) frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)