diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f37530483..93d8a43bc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -446,23 +446,24 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + src_att

         # convolution module
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.nonlin_attention_module(src,
-                                                     attn_weights,
-                                                     head_idx=0)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward2(src)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.squeeze_excite1(src, attn_weights, head_idx=1)
+            src = src + self.squeeze_excite1(src, attn_weights, head_idx=0)

         if torch.jit.is_scripting() or use_self_attn:
             self_attn_output2 = self.self_attn.forward2(src, attn_weights)
             src = src + self_attn_output2

-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        # attention version of convolution module
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.nonlin_attention_module(src,
+                                                     attn_weights,
+                                                     head_idx=1)

         src = src + self.feed_forward3(src)

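Note on the gating pattern used above for the convolution module: the branch runs the sub-module at random during eager training (with probability 1 - dynamic_dropout) but always runs it once the model is scripted, so the exported graph is deterministic. A minimal, self-contained sketch of this pattern follows; GatedResidualBlock and its nn.Linear sub-module are hypothetical stand-ins, not the actual Zipformer code.

# Minimal sketch, assuming a generic residual sub-module in place of
# ZipformerEncoderLayer's conv_module; names here are illustrative only.
import random

import torch
import torch.nn as nn


class GatedResidualBlock(nn.Module):
    def __init__(self, d_model: int, dynamic_dropout: float = 0.1):
        super().__init__()
        self.module = nn.Linear(d_model, d_model)  # placeholder sub-module
        self.dynamic_dropout = dynamic_dropout

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # Under torch.jit.is_scripting() the branch is always taken, so a
        # scripted model behaves deterministically; in eager mode the
        # sub-module is skipped with probability dynamic_dropout while the
        # residual path keeps the layer's output well-defined.
        if torch.jit.is_scripting() or random.random() > self.dynamic_dropout:
            src = src + self.module(src)
        return src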