diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 35daf4f39..3127c5d2a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -389,10 +389,7 @@ class ZipformerEncoderLayer(nn.Module):
             dropout=0.0,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim)
-
-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
+        self.self_attn = SelfAttention(embed_dim, num_heads,
                                         value_head_dim)
 
         self.feed_forward1 = FeedforwardModule(embed_dim,
@@ -489,11 +486,6 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
 
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn1(
-                src, attn_weights)
-
-        # convolution module
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src, attn_weights[0:1])
 
@@ -505,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn2(
+            src = src + self.self_attn(
                 src, attn_weights)
 
         if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
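
The diff drops the first of the two self-attention applications in each ZipformerEncoderLayer and renames the surviving module from self_attn2 to self_attn. This is cheap to do because of how the layer is structured: the attention weights are computed once per layer by a separate weights module, and each SelfAttention instance only projects values and applies those shared, precomputed weights. Below is a minimal sketch of that pattern, not the repo's actual implementation; the shapes and projection names are assumptions for illustration only.

# Minimal sketch (assumed shapes, not the repo's actual code) of a
# SelfAttention module that applies externally computed attention weights.
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Projects values and applies precomputed attention weights."""

    def __init__(self, embed_dim: int, num_heads: int, value_head_dim: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.value_head_dim = value_head_dim
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim)
        self.out_proj = nn.Linear(num_heads * value_head_dim, embed_dim)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, embed_dim)
        # attn_weights: (num_heads, batch, seq_len, seq_len), computed once
        # per layer by a separate attention-weights module and reused here.
        seq_len, batch, _ = x.shape
        v = self.in_proj(x)  # (seq_len, batch, num_heads * value_head_dim)
        v = v.reshape(seq_len, batch, self.num_heads, self.value_head_dim)
        v = v.permute(2, 1, 0, 3)  # (num_heads, batch, seq_len, value_head_dim)
        # Weighted sum over source positions using the shared weights.
        out = torch.matmul(attn_weights, v)
        out = out.permute(2, 1, 0, 3).reshape(seq_len, batch, -1)
        return self.out_proj(out)

Since attn_weights are produced once and merely consumed by SelfAttention, removing one application site saves a value projection and a weighted sum per layer without touching the (more expensive) attention-weight computation; the nonlin_attention_module and attention_squeeze1 calls continue to reuse slices of the same weights.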