Reduce dimension for speed, have varying dims

This commit is contained in:
Daniel Povey 2023-01-12 21:15:39 +08:00
parent 9e4b84f374
commit 1e04c3d892

View File

@@ -123,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--num-encoder-layers",
type=str, type=str,
default="4,4,6,4", default="2,2,4,6,4,2",
help="Number of zipformer encoder layers per stack, comma separated.", help="Number of zipformer encoder layers per stack, comma separated.",
) )
@@ -131,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--downsampling-factor", "--downsampling-factor",
type=str, type=str,
default="1,2,4,2", default="1,2,4,8,4,2",
help="Downsampling factor for each stack of encoder layers.", help="Downsampling factor for each stack of encoder layers.",
) )
@@ -139,14 +139,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--feedforward-dim", "--feedforward-dim",
type=str, type=str,
default="1536,1536,1536,1536", default="384,768,1024,1536,1024,768",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--num-heads", "--num-heads",
type=str, type=str,
default="8,8,8,8", default="4,4,4,8,4,4",
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
) )
@@ -160,7 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim",
type=str, type=str,
default="384", default="192,256,320,384,320,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list." help="Embedding dimension in encoder stacks: a single int or comma-separated list."
) )
@@ -195,7 +195,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-unmasked-dim", "--encoder-unmasked-dim",
type=str, type=str,
default="256", default="164,192,256,256,256,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. " help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim." "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
) )
@@ -203,7 +203,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--cnn-module-kernel", "--cnn-module-kernel",
type=str, type=str,
default="31", default="31,31,15,15,15,31",
help="Sizes of convolutional kernels in convolution modules in each encoder stack: " help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.", "a single int or comma-separated list.",
) )