Mirror of https://github.com/k2-fsa/icefall.git
Remove self_attn1 module

commit 22a1401f36 (parent d542fa61ff)
@@ -389,10 +389,7 @@ class ZipformerEncoderLayer(nn.Module):
             dropout=0.0,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim)
-
-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
+        self.self_attn = SelfAttention(embed_dim, num_heads,
                                         value_head_dim)
 
         self.feed_forward1 = FeedforwardModule(embed_dim,
@@ -489,11 +486,6 @@ class ZipformerEncoderLayer(nn.Module):
                 key_padding_mask=src_key_padding_mask,
             )
 
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn1(
-                src, attn_weights)
-
-        # convolution module
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
@@ -505,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn2(
+            src = src + self.self_attn(
                 src, attn_weights)
 
         if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
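Net effect of the commit: ZipformerEncoderLayer previously constructed two separate modules, self_attn1 and self_attn2, and applied each once in the forward pass; after this change a single self_attn module is constructed and reused at both call sites, fed the same precomputed attn_weights. The sketch below illustrates that sharing pattern in a minimal, self-contained form; SimpleSelfAttention and all tensor shapes here are illustrative assumptions, not the actual icefall SelfAttention implementation.

import torch
import torch.nn as nn


class SimpleSelfAttention(nn.Module):
    """Apply precomputed, already-normalized attention weights to a value
    projection of the input (a simplified stand-in for icefall's SelfAttention)."""

    def __init__(self, embed_dim: int, num_heads: int, value_head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.value_head_dim = value_head_dim
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim)
        self.out_proj = nn.Linear(num_heads * value_head_dim, embed_dim)

    def forward(self, src: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # src:          (seq_len, batch, embed_dim)
        # attn_weights: (num_heads, batch, seq_len, seq_len), rows sum to 1
        seq_len, batch, _ = src.shape
        v = self.in_proj(src)                       # (seq, batch, heads * vdim)
        v = v.reshape(seq_len, batch, self.num_heads, self.value_head_dim)
        v = v.permute(2, 1, 0, 3)                   # (heads, batch, seq, vdim)
        out = torch.matmul(attn_weights, v)         # (heads, batch, seq, vdim)
        out = out.permute(2, 1, 0, 3).reshape(seq_len, batch, -1)
        return self.out_proj(out)                   # (seq, batch, embed_dim)


embed_dim, num_heads, value_head_dim = 16, 4, 4
seq_len, batch = 10, 2

# One shared module replaces the former self_attn1 / self_attn2 pair.
self_attn = SimpleSelfAttention(embed_dim, num_heads, value_head_dim)

src = torch.randn(seq_len, batch, embed_dim)
attn_weights = torch.softmax(torch.randn(num_heads, batch, seq_len, seq_len), dim=-1)

src = src + self_attn(src, attn_weights)  # first call site
# ... nonlin_attention / convolution / feed-forward modules would go here ...
src = src + self_attn(src, attn_weights)  # second call site, same module

Assuming the two original modules were otherwise identical, sharing one roughly halves the per-layer self-attention parameter count while keeping both residual applications of the shared attention weights.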