From 22a1401f36712177875700b615e932bcb63cfec9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 15 Nov 2022 14:31:37 +0800
Subject: [PATCH] Remove self_attn1 module

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 35daf4f39..3127c5d2a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -389,10 +389,7 @@ class ZipformerEncoderLayer(nn.Module):
             dropout=0.0,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim)
-
-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
+        self.self_attn = SelfAttention(embed_dim, num_heads,
                                         value_head_dim)
 
         self.feed_forward1 = FeedforwardModule(embed_dim,
@@ -489,11 +486,6 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
 
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn1(
-                src, attn_weights)
-
-        # convolution module
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src, attn_weights[0:1])
 
@@ -505,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module):
            src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
        if torch.jit.is_scripting() or use_self_attn:
-           src = src + self.self_attn2(
+           src = src + self.self_attn(
                src, attn_weights)
 
        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
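
Below is a minimal, self-contained sketch (not the actual zipformer.py code) of the layer structure this patch leaves behind: attention weights are computed once per layer and passed to the sub-modules, and after this change a single self_attn instance consumes them where self_attn1 and self_attn2 did before, matching the remaining call site "src = src + self.self_attn(src, attn_weights)" in the last hunk. The names AttentionWeights, SimpleSelfAttention, SimpleEncoderLayer, the self_attn_weights attribute, and all shapes are illustrative assumptions, not the real implementation.

import torch
import torch.nn as nn

class AttentionWeights(nn.Module):
    """Computes (num_heads, batch, seq, seq) attention weights from the input."""
    def __init__(self, embed_dim: int, num_heads: int, query_head_dim: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.in_proj = nn.Linear(embed_dim, 2 * num_heads * query_head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq, batch, embed_dim)
        seq, batch, _ = x.shape
        qk = self.in_proj(x).reshape(seq, batch, self.num_heads, 2, self.query_head_dim)
        q, k = qk.unbind(dim=3)  # each: (seq, batch, heads, head_dim)
        scores = torch.einsum("qbhd,kbhd->hbqk", q, k) / self.query_head_dim ** 0.5
        return scores.softmax(dim=-1)  # (heads, batch, seq, seq)

class SimpleSelfAttention(nn.Module):
    """Applies externally supplied attention weights to a value projection."""
    def __init__(self, embed_dim: int, num_heads: int, value_head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.value_head_dim = value_head_dim
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim)
        self.out_proj = nn.Linear(num_heads * value_head_dim, embed_dim)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        seq, batch, _ = x.shape
        v = self.in_proj(x).reshape(seq, batch, self.num_heads, self.value_head_dim)
        out = torch.einsum("hbqk,kbhd->qbhd", attn_weights, v)
        return self.out_proj(out.reshape(seq, batch, -1))

class SimpleEncoderLayer(nn.Module):
    def __init__(self, embed_dim: int = 16, num_heads: int = 4, value_head_dim: int = 4):
        super().__init__()
        self.self_attn_weights = AttentionWeights(embed_dim, num_heads)
        # After the patch there is a single self_attn instead of self_attn1 / self_attn2.
        self.self_attn = SimpleSelfAttention(embed_dim, num_heads, value_head_dim)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        attn_weights = self.self_attn_weights(src)      # computed once, shared downstream
        src = src + self.self_attn(src, attn_weights)   # the one remaining call site
        return src

if __name__ == "__main__":
    layer = SimpleEncoderLayer()
    src = torch.randn(10, 2, 16)   # (seq, batch, embed_dim)
    print(layer(src).shape)        # torch.Size([10, 2, 16])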