Change order of convolution and nonlin-attention modules
This commit is contained in:
parent 36bff9b369
commit 797a0e6ce7
@@ -446,23 +446,24 @@ class ZipformerEncoderLayer(nn.Module):
        src = src + src_att

        # convolution module
        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.nonlin_attention_module(src,
                                                     attn_weights,
                                                     head_idx=0)
        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

        src = src + self.feed_forward2(src)

        # pooling module
        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.squeeze_excite1(src, attn_weights, head_idx=1)
            src = src + self.squeeze_excite1(src, attn_weights, head_idx=0)

        if torch.jit.is_scripting() or use_self_attn:
            self_attn_output2 = self.self_attn.forward2(src, attn_weights)
            src = src + self_attn_output2

        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
        # attention version of convolution module
        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.nonlin_attention_module(src,
                                                     attn_weights,
                                                     head_idx=1)

        src = src + self.feed_forward3(src)
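For readers unfamiliar with the `random.random() > dynamic_dropout` guard that wraps conv_module in this hunk: it stochastically skips the convolution branch with probability `dynamic_dropout` during eager execution, while a torch.jit-scripted (exported) model always runs it. Below is a minimal, self-contained sketch of that pattern; `StochasticBypass` and `skip_prob` are hypothetical names used only for illustration and are not part of this repository.

import random

import torch
import torch.nn as nn


class StochasticBypass(nn.Module):
    # Hypothetical wrapper mirroring the guard used around conv_module above:
    # with probability `skip_prob` the wrapped module's residual contribution
    # is dropped; torch.jit-scripted models always execute it.
    def __init__(self, module: nn.Module, skip_prob: float = 0.1):
        super().__init__()
        self.module = module
        self.skip_prob = skip_prob

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting() or random.random() > self.skip_prob:
            return src + self.module(src)
        return src


# Usage: randomly bypass a depthwise 1-D convolution branch half of the time.
conv = nn.Conv1d(16, 16, kernel_size=3, padding=1, groups=16)
layer = StochasticBypass(conv, skip_prob=0.5)
x = torch.randn(2, 16, 50)  # (batch, channels, time)
print(layer(x).shape)       # torch.Size([2, 16, 50])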