Remove self_attn1 module

This commit is contained in:
Daniel Povey 2022-11-15 14:31:37 +08:00
parent d542fa61ff
commit 22a1401f36

View File

@ -389,10 +389,7 @@ class ZipformerEncoderLayer(nn.Module):
dropout=0.0,
)
self.self_attn1 = SelfAttention(embed_dim, num_heads,
value_head_dim)
self.self_attn2 = SelfAttention(embed_dim, num_heads,
self.self_attn = SelfAttention(embed_dim, num_heads,
value_head_dim)
self.feed_forward1 = FeedforwardModule(embed_dim,
@ -489,11 +486,6 @@ class ZipformerEncoderLayer(nn.Module):
key_padding_mask=src_key_padding_mask,
)
if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn1(
src, attn_weights)
# convolution module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.nonlin_attention_module(src,
attn_weights[0:1])
@ -505,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.attention_squeeze1(src, attn_weights[1:2])
if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn2(
src = src + self.self_attn(
src, attn_weights)
if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob: