arch align

This commit is contained in:
jinzr 2023-10-19 11:09:20 +08:00
parent f3f0dfc52d
commit f3b918452a

View File

@ -127,35 +127,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--num-encoder-layers",
type=str, type=str,
default="2,2,3,4,3,2", default="2,4,3,2,4",
help="Number of zipformer encoder layers per stack, comma separated.", help="Number of zipformer encoder layers per stack, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--downsampling-factor", "--downsampling-factor",
type=str, type=str,
default="1,2,4,8,4,2", default="1,2,4,8,2",
help="Downsampling factor for each stack of encoder layers.", help="Downsampling factor for each stack of encoder layers.",
) )
parser.add_argument( parser.add_argument(
"--feedforward-dim", "--feedforward-dim",
type=str, type=str,
default="512,768,1024,1536,1024,768", default="1024,1024,2048,2048,1024",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--num-heads", "--num-heads",
type=str, type=str,
default="4,4,4,8,4,4", default="8,8,8,8,8",
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
) )
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim",
type=str, type=str,
default="192,256,384,512,384,256", default="384,384,384,384,384",
help="Embedding dimension in encoder stacks: a single int or comma-separated list.", help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
) )
@ -190,7 +190,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-unmasked-dim", "--encoder-unmasked-dim",
type=str, type=str,
default="192,192,256,256,256,192", default="256,256,256,256,256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. " help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.", "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
) )
@ -198,7 +198,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--cnn-module-kernel", "--cnn-module-kernel",
type=str, type=str,
default="31,31,15,15,15,31", default="31,31,31,31,31",
help="Sizes of convolutional kernels in convolution modules in each encoder stack: " help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.", "a single int or comma-separated list.",
) )
@ -329,7 +329,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--base-lr", type=float, default=0.045, help="The base learning rate." "--base-lr", type=float, default=0.05, help="The base learning rate."
) )
parser.add_argument( parser.add_argument(