Move ModifiedSEModule to end of ZipformerEncoderLayer
parent a27670d097 · commit 44bdda1218
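Summary of the change (as read from the diff below): in __init__, self.squeeze_excite = ModifiedSEModule(d_model) is relocated from before feed_forward1 to just before norm_final, and in forward() the corresponding residual call moves from before the self-attention module to after feed_forward3, immediately before the final norm. The squeeze-excitation branch therefore now runs last in the layer.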
@@ -331,8 +331,6 @@ class ZipformerEncoderLayer(nn.Module):
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
 
-        self.squeeze_excite = ModifiedSEModule(d_model)
-
         self.feed_forward1 = FeedforwardModule(d_model,
                                                feedforward_dim,
                                                dropout)
@@ -351,6 +349,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.conv_module2 = ConvolutionModule(d_model,
                                               cnn_module_kernel)
 
+        self.squeeze_excite = ModifiedSEModule(d_model)
+
+
         self.norm_final = BasicNorm(d_model)
 
         self.bypass_scale = nn.Parameter(torch.tensor(0.5))
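ModifiedSEModule's definition is not part of this diff. For orientation only, a minimal sketch of a squeeze-and-excitation block with the same call signature (input of shape (seq_len, batch, d_model) plus an optional key_padding_mask, returning a branch that the caller adds residually) might look like the following. The class body, the bottleneck size, and the masked-mean details are assumptions, not the actual icefall implementation.

import torch
import torch.nn as nn


class SqueezeExciteSketch(nn.Module):
    """Hypothetical stand-in for ModifiedSEModule: masked global average
    pooling over time, a bottleneck MLP, and a sigmoid channel gate."""

    def __init__(self, d_model: int, bottleneck: int = 64):  # bottleneck size assumed
        super().__init__()
        self.down = nn.Linear(d_model, bottleneck)
        self.up = nn.Linear(bottleneck, d_model)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        # x: (seq_len, batch, d_model); key_padding_mask: (batch, seq_len), True = pad.
        if key_padding_mask is not None:
            valid = (~key_padding_mask).t().unsqueeze(-1).to(x.dtype)  # (seq_len, batch, 1)
            mean = (x * valid).sum(dim=0) / valid.sum(dim=0).clamp(min=1.0)
        else:
            mean = x.mean(dim=0)  # (batch, d_model)
        gate = torch.sigmoid(self.up(torch.relu(self.down(mean))))
        # Return the gated branch; the layer adds it as src = src + self.squeeze_excite(...)
        return x * gate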
@@ -430,11 +431,6 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
 
-        # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.squeeze_excite(src,
-                                            key_padding_mask=src_key_padding_mask)
-
         # multi-headed self-attention module
         use_self_attn = (random.random() > dynamic_dropout)
         if torch.jit.is_scripting() or use_self_attn:
@@ -460,6 +456,11 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward3(src)
 
+        # pooling module
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.squeeze_excite(src,
+                                            key_padding_mask=src_key_padding_mask)
+
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
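The relocated call keeps the layer's stochastic skipping pattern visible in the hunks above: in eager mode a submodule is bypassed with probability dynamic_dropout, while under TorchScript (torch.jit.is_scripting()) the branch is always taken so scripted inference is deterministic. A minimal sketch of that guard, written as a free-standing helper for illustration (the name apply_with_dynamic_dropout is assumed, not from icefall):

import random
import torch
import torch.nn as nn


def apply_with_dynamic_dropout(module: nn.Module, src: torch.Tensor,
                               dropout_rate: float) -> torch.Tensor:
    # Skip the submodule with probability `dropout_rate` during eager-mode
    # training, but always run it when the model has been scripted, so
    # exported graphs contain no data-independent randomness.
    if torch.jit.is_scripting() or random.random() > dropout_rate:
        src = src + module(src)
    return src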