Have 6 different encoder stacks, U-shaped network.

This commit is contained in:
Daniel Povey 2022-10-28 20:36:45 +08:00
parent 7b57a34227
commit 96ea4cf1be

View File

@@ -93,35 +93,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--num-encoder-layers",
type=str, type=str,
default="2,4,4,4", default="2,3,3,3,3,3",
help="Number of zipformer encoder layers, comma separated.", help="Number of zipformer encoder layers, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--feedforward-dims", "--feedforward-dims",
type=str, type=str,
default="1024,1536,1536,1536", default="1024,1024,1536,1536,1536,1024",
help="Feedforward dimension of the zipformer encoder layers, comma separated.", help="Feedforward dimension of the zipformer encoder layers, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--nhead", "--nhead",
type=str, type=str,
default="8,8,8,8", default="8,8,8,8,8,8",
help="Number of attention heads in the zipformer encoder layers.", help="Number of attention heads in the zipformer encoder layers.",
) )
parser.add_argument( parser.add_argument(
"--encoder-dims", "--encoder-dims",
type=str, type=str,
default="384,384,384,512", default="384,384,384,384,384,384",
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
) )
parser.add_argument( parser.add_argument(
"--attention-dims", "--attention-dims",
type=str, type=str,
default="192,192,192,256", default="192,192,192,192,192,192",
help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
not the same as embedding dimension.""" not the same as embedding dimension."""
) )
@@ -129,7 +129,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-unmasked-dims", "--encoder-unmasked-dims",
type=str, type=str,
default="256,256,256,256", default="256,256,256,256,256,256",
help="Unmasked dimensions in the encoders, relates to augmentation during training. " help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
" worse." " worse."
@@ -138,14 +138,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--zipformer-downsampling-factors", "--zipformer-downsampling-factors",
type=str, type=str,
default="1,2,4,8", default="1,2,4,8,4,2",
help="Downsampling factor for each stack of encoder layers.", help="Downsampling factor for each stack of encoder layers.",
) )
parser.add_argument( parser.add_argument(
"--cnn-module-kernels", "--cnn-module-kernels",
type=str, type=str,
default="31,31,31,31", default="31,31,31,31,31,31",
help="Sizes of kernels in convolution modules", help="Sizes of kernels in convolution modules",
) )