Mirror of https://github.com/k2-fsa/icefall.git

commit fc74ff63fb (parent 3d47335ab6)

Remove one feedforward module and give params to the other 2.
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1024,1024,1536,1024,1024,1024",
+        default="1536,1536,2048,1536,1536,1536",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )

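The new default raises the per-stack feedforward dimensions from 1024/1536 to 1536/2048. Below is a minimal sketch of how such a comma-separated per-stack string could be turned into one integer per zipformer stack; to_int_tuple is a hypothetical helper name chosen for illustration, not necessarily the one used in the icefall training script.

# Minimal sketch, assuming the option is parsed into one int per encoder stack.
# "to_int_tuple" is a hypothetical name used only for this example.
def to_int_tuple(s: str) -> tuple:
    return tuple(int(x) for x in s.split(","))

feedforward_dim = to_int_tuple("1536,1536,2048,1536,1536,1536")
assert len(feedforward_dim) == 6  # one value per zipformer encoder stack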
@@ -400,10 +400,6 @@ class ZipformerEncoderLayer(nn.Module):
                                                feedforward_dim,
                                                dropout)

-        self.feed_forward3 = FeedforwardModule(embed_dim,
-                                               feedforward_dim,
-                                               dropout)
-
         #self.conv_module1 = ConvolutionModule(embed_dim,
         #cnn_module_kernel)
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
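Removing feed_forward3 is roughly parameter-neutral once the larger defaults from the first hunk are taken into account, which is what "give params to the other 2" in the commit message refers to. A back-of-the-envelope check, assuming each FeedforwardModule costs about 2 * embed_dim * feedforward_dim parameters (an in-projection plus an out-projection, biases ignored) and using embed_dim = 384 purely as an example value:

# Rough, assumption-laden budget check; not the exact module sizes.
embed_dim = 384                      # example value only

old_ff = 3 * 2 * embed_dim * 1024    # three feedforward modules at the old default
new_ff = 2 * 2 * embed_dim * 1536    # two feedforward modules at the new default
print(old_ff, new_ff)                # both 2359296: the removed module's budget
                                     # moves to the remaining two

For the stack whose dimension goes from 1536 to 2048 the budget shrinks slightly (3 * 1536 = 4608 vs 2 * 2048 = 4096 hidden units per layer).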
@@ -469,9 +465,6 @@ class ZipformerEncoderLayer(nn.Module):

         src_orig = src

-        # macaron style feed forward module
-        src = src + self.feed_forward1(src)
-
         # dropout rate for non-feedforward submodules
         dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
         # multi-headed self-attention module
@@ -490,7 +483,7 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.nonlin_attention_module(src,
                                                  attn_weights[0:1])

-        src = src + self.feed_forward2(src)
+        src = src + self.feed_forward1(src)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
@@ -503,7 +496,7 @@ class ZipformerEncoderLayer(nn.Module):
         if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

-        src = src + self.feed_forward3(src)
+        src = src + self.feed_forward2(src)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
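Taken together, the last three hunks drop the macaron-style feedforward module that ran before self-attention and renumber the remaining two. The toy layer below is a runnable sketch of the resulting residual ordering only; the real ZipformerEncoderLayer has self-attention, pooling and dynamic-skip logic between these lines, and every submodule here is replaced by nn.Identity.

import torch
import torch.nn as nn

# Toy stand-in only: nn.Identity replaces every real submodule so that the
# post-commit residual ordering can be run end to end.
class ToyEncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.nonlin_attention_module = nn.Identity()
        self.feed_forward1 = nn.Identity()   # previously feed_forward2
        self.conv_module = nn.Identity()
        self.feed_forward2 = nn.Identity()   # previously feed_forward3

    def forward(self, src):
        # the old macaron-style feedforward before self-attention is gone
        src = src + self.nonlin_attention_module(src)
        src = src + self.feed_forward1(src)
        # ... self-attention / pooling omitted ...
        src = src + self.conv_module(src)
        src = src + self.feed_forward2(src)
        return src

print(ToyEncoderLayer()(torch.zeros(2, 3, 4)).shape)  # torch.Size([2, 3, 4])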