From a27670d0972405d23272d379dc454dfe6358979e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Nov 2022 12:41:19 +0800 Subject: [PATCH] Restore feedforward3 module --- .../ASR/pruned_transducer_stateless7/zipformer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 9ccdaab99..1c181b1ac 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -338,8 +338,12 @@ class ZipformerEncoderLayer(nn.Module): dropout) self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + feedforward_dim, + dropout) + + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) @@ -446,13 +450,15 @@ class ZipformerEncoderLayer(nn.Module): if torch.jit.is_scripting() or random.random() > dynamic_dropout: src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask) + src = src + self.feed_forward2(src) + if torch.jit.is_scripting() or use_self_attn: src = src + self.self_attn.forward2(src, attn_weights) if torch.jit.is_scripting() or random.random() > dynamic_dropout: src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask) - src = src + self.feed_forward2(src) + src = src + self.feed_forward3(src) src = self.norm_final(self.balancer(src))