Mirror of https://github.com/k2-fsa/icefall.git
Rework zipformer code for clarity and extensibility

parent 797a0e6ce7
commit 20e6d2a157
@@ -105,59 +105,81 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--num-encoder-layers",
         type=str,
         default="2,4,3,2,2,4",
-        help="Number of zipformer encoder layers, comma separated.",
+        help="Number of zipformer encoder layers per stack, comma separated.",
     )
 
     parser.add_argument(
-        "--feedforward-dims",
-        type=str,
-        default="1024,1024,1536,1536,1536,1024",
-        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
-    )
-
-    parser.add_argument(
-        "--nhead",
-        type=str,
-        default="8,8,8,8,8,8",
-        help="Number of attention heads in the zipformer encoder layers.",
-    )
-
-    parser.add_argument(
-        "--encoder-dims",
-        type=str,
-        default="384,384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
-    )
-
-    parser.add_argument(
-        "--attention-dims",
-        type=str,
-        default="192,192,192,192,192,192",
-        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension."""
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dims",
-        type=str,
-        default="256,256,256,256,256,256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse."
-    )
-
-    parser.add_argument(
-        "--zipformer-downsampling-factors",
+        "--downsampling-factor",
         type=str,
         default="1,2,4,8,4,2",
         help="Downsampling factor for each stack of encoder layers.",
     )
 
     parser.add_argument(
-        "--cnn-module-kernels",
+        "--feedforward-dim",
         type=str,
-        default="31,31,31,31,31,31",
-        help="Sizes of kernels in convolution modules",
+        default="1024,1024,1536,1536,1536,1024",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="8",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="384",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="24",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default="192",
+        help="Positional-encoding embedding dimension"
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
+    )
 
     parser.add_argument(
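Note: the new options accept either a single value, applied to every encoder stack, or one value per stack. Below is a minimal sketch of that parsing. `to_int_tuple` mirrors the helper referenced in `get_encoder_model` further down; the broadcasting of a length-1 result across stacks is an assumption inferred from the "single int or comma-separated list" help strings, not something this diff shows.

# Sketch only: broadcast_per_stack is a hypothetical helper illustrating the
# assumed handling of single-valued options such as --encoder-dim 384.
from typing import Tuple

def to_int_tuple(s: str) -> Tuple[int, ...]:
    # "2,4,3,2,2,4" -> (2, 4, 3, 2, 2, 4); "384" -> (384,)
    return tuple(int(x) for x in s.split(","))

def broadcast_per_stack(t: Tuple[int, ...], num_stacks: int) -> Tuple[int, ...]:
    # Expand a single value to one entry per encoder stack.
    if len(t) == 1:
        return t * num_stacks
    assert len(t) == num_stacks, (t, num_stacks)
    return t

num_stacks = len(to_int_tuple("2,4,3,2,2,4"))  # from --num-encoder-layers
print(broadcast_per_stack(to_int_tuple("384"), num_stacks))
# (384, 384, 384, 384, 384, 384)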
@@ -455,14 +477,19 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
-        zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
-        encoder_dims=to_int_tuple(params.encoder_dims),
-        attention_dim=to_int_tuple(params.attention_dims),
-        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
-        nhead=to_int_tuple(params.nhead),
-        feedforward_dim=to_int_tuple(params.feedforward_dims),
-        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        downsampling_factor=to_int_tuple(params.downsampling_factor),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+        encoder_dim=to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=to_int_tuple(params.query_head_dim),
+        pos_head_dim=to_int_tuple(params.pos_head_dim),
+        value_head_dim=to_int_tuple(params.value_head_dim),
+        pos_dim=params.pos_dim,
+        num_heads=to_int_tuple(params.num_heads),
+        feedforward_dim=to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=0.1,
         warmup_batches=4000.0,
     )
     return encoder
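Note: after the rework, get_encoder_model() passes every per-stack hyperparameter through the same to_int_tuple conversion, with pos_dim as the one scalar exception. A sketch of that uniform pattern follows; the key list is copied from the call above, but encoder_kwargs itself is a hypothetical helper, not code from the diff.

# Hypothetical helper (not in the diff) showing the uniform conversion in
# the new get_encoder_model(): each per-stack option is a comma-separated
# string on `params` and becomes an int tuple; pos_dim alone stays an int.
def to_int_tuple(s: str):
    return tuple(int(x) for x in s.split(","))

PER_STACK_KEYS = [
    "downsampling_factor", "num_encoder_layers", "encoder_dim",
    "encoder_unmasked_dim", "query_head_dim", "pos_head_dim",
    "value_head_dim", "num_heads", "feedforward_dim", "cnn_module_kernel",
]

def encoder_kwargs(params) -> dict:
    kwargs = {k: to_int_tuple(getattr(params, k)) for k in PER_STACK_KEYS}
    kwargs["pos_dim"] = params.pos_dim  # the one scalar option
    return kwargs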
@@ -479,7 +506,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dims.split(',')[-1]),
+        encoder_dim=int(params.encoder_dim.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -496,7 +523,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dims.split(',')[-1]),
+        encoder_dim=int(params.encoder_dim.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
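Note: in both hunks the joiner and transducer take their encoder_dim from the last entry of the comma-separated list, presumably the output dimension of the final encoder stack. Because a single value like "384" is its own last entry, the same expression covers both accepted formats. A quick check; the non-uniform value is made up for the example:

# split(',')[-1] picks the final stack's dimension and also works for
# single-valued settings such as the default "384".
for encoder_dim in ("384", "192,256,384,512,384,256"):
    print(int(encoder_dim.split(",")[-1]))  # 384, then 256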
File diff suppressed because it is too large