diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1c181b1ac..bfb122d69 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -331,8 +331,6 @@ class ZipformerEncoderLayer(nn.Module):
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
 
-        self.squeeze_excite = ModifiedSEModule(d_model)
-
         self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout)
 
@@ -351,6 +349,9 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel)
 
+
+        self.squeeze_excite = ModifiedSEModule(d_model)
+
         self.norm_final = BasicNorm(d_model)
 
         self.bypass_scale = nn.Parameter(torch.tensor(0.5))
 
@@ -430,11 +431,6 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
 
-        # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.squeeze_excite(src,
-                key_padding_mask=src_key_padding_mask)
-
         # multi-headed self-attention module
         use_self_attn = (random.random() > dynamic_dropout)
         if torch.jit.is_scripting() or use_self_attn:
@@ -460,6 +456,11 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward3(src)
 
+        # pooling module
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.squeeze_excite(src,
+                key_padding_mask=src_key_padding_mask)
+
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
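
The diff moves the `ModifiedSEModule` (held as `self.squeeze_excite`) from before the self-attention block to after `feed_forward3`, but the module itself is not shown. For reference, here is a minimal sketch of a squeeze-and-excitation style module that matches the constructor argument (`d_model`) and call signature (`src`, `key_padding_mask=...`) seen in the diff; the internals (masked mean pooling over time, a `d_model // 4` bottleneck, sigmoid gating) and the `(seq_len, batch, d_model)` layout are assumptions, not the actual implementation in this file.

```python
import torch
import torch.nn as nn


class ModifiedSEModule(nn.Module):
    """Hypothetical squeeze-and-excitation gate over the time axis (sketch only)."""

    def __init__(self, d_model: int):
        super().__init__()
        bottleneck = max(1, d_model // 4)  # assumed reduction ratio of 4
        self.proj1 = nn.Linear(d_model, bottleneck)
        self.proj2 = nn.Linear(bottleneck, d_model)

    def forward(self, src, key_padding_mask=None):
        # src: (seq_len, batch, d_model); key_padding_mask: (batch, seq_len),
        # True at padded positions (assumed convention, as in nn.MultiheadAttention).
        if key_padding_mask is not None:
            keep = (~key_padding_mask).t().unsqueeze(-1).to(src.dtype)  # (seq_len, batch, 1)
            denom = keep.sum(dim=0).clamp(min=1.0)                      # (batch, 1)
            pooled = (src * keep).sum(dim=0) / denom                    # (batch, d_model)
        else:
            pooled = src.mean(dim=0)                                    # (batch, d_model)
        # Per-channel gates, broadcast back over the time dimension.
        gates = torch.sigmoid(self.proj2(torch.relu(self.proj1(pooled))))
        return src * gates.unsqueeze(0)
```

Because the encoder layer adds the module's output residually (`src = src + self.squeeze_excite(...)` in the diff), this sketch returns only the gated features rather than the residual sum.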