From fc74ff63fbc93d313af87bcc5b38eba0ee3331a3 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 15 Nov 2022 13:57:36 +0800
Subject: [PATCH] Remove one feedforward module and give params to the other 2.

---
 .../ASR/pruned_transducer_stateless7/train.py     |  2 +-
 .../ASR/pruned_transducer_stateless7/zipformer.py | 11 ++---------
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 58f61d8a4..4969ed96d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1024,1024,1536,1024,1024,1024",
+        default="1536,1536,2048,1536,1536,1536",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 053c54ce9..0d275832b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -400,10 +400,6 @@ class ZipformerEncoderLayer(nn.Module):
                                                feedforward_dim,
                                                dropout)
 
-        self.feed_forward3 = FeedforwardModule(embed_dim,
-                                               feedforward_dim,
-                                               dropout)
-
         #self.conv_module1 = ConvolutionModule(embed_dim,
         #                                      cnn_module_kernel)
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
@@ -469,9 +465,6 @@ class ZipformerEncoderLayer(nn.Module):
 
         src_orig = src
 
-        # macaron style feed forward module
-        src = src + self.feed_forward1(src)
-
         # dropout rate for non-feedforward submodules
         dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
         # multi-headed self-attention module
@@ -490,7 +483,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.nonlin_attention_module(src, attn_weights[0:1])
 
-        src = src + self.feed_forward2(src)
+        src = src + self.feed_forward1(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
@@ -503,7 +496,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
-        src = src + self.feed_forward3(src)
+        src = src + self.feed_forward2(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
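
Note on the parameter budget: each FeedforwardModule contributes roughly
2 * embed_dim * feedforward_dim weights per layer, so dropping one of three
modules at feedforward_dim=1024 while growing the remaining two to 1536 keeps
the per-layer feedforward budget unchanged (3 * 1024 = 2 * 1536 = 3072), while
the middle stack shrinks slightly (2 * 2048 = 4096 < 3 * 1536 = 4608). The
sketch below illustrates the resulting residual layout of the encoder layer.
It is a hypothetical, minimal approximation, not the actual icefall code:
SimpleFeedforward (using SiLU) stands in for FeedforwardModule, the dimensions
are example values, and the self-attention, pooling, and convolution residuals
are reduced to comments.

    import torch
    import torch.nn as nn


    class SimpleFeedforward(nn.Module):
        # Stand-in for icefall's FeedforwardModule:
        # linear -> activation -> dropout -> linear.
        def __init__(self, embed_dim: int, feedforward_dim: int, dropout: float):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(embed_dim, feedforward_dim),
                nn.SiLU(),  # hypothetical stand-in for icefall's activation
                nn.Dropout(dropout),
                nn.Linear(feedforward_dim, embed_dim),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)


    class TwoFeedforwardLayer(nn.Module):
        # Residual layout after this patch: the macaron-style feedforward that
        # used to run before self-attention is gone; the two remaining (larger)
        # feedforward modules follow the nonlinear-attention and convolution
        # blocks, and the old feed_forward2/feed_forward3 are renumbered 1/2.
        def __init__(self, embed_dim: int = 384, feedforward_dim: int = 1536,
                     dropout: float = 0.1):
            super().__init__()
            self.feed_forward1 = SimpleFeedforward(embed_dim, feedforward_dim, dropout)
            self.feed_forward2 = SimpleFeedforward(embed_dim, feedforward_dim, dropout)

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            # ... self-attention and nonlinear-attention residuals elided ...
            src = src + self.feed_forward1(src)  # was feed_forward2 before the patch
            # ... convolution-module residual elided ...
            src = src + self.feed_forward2(src)  # was feed_forward3 before the patch
            return src


    # Usage: (seq_len, batch, embed_dim) input, following the Zipformer convention.
    layer = TwoFeedforwardLayer()
    out = layer(torch.randn(100, 8, 384))
    assert out.shape == (100, 8, 384)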