Revert model-size changes

This commit is contained in:
Daniel Povey 2023-05-30 14:49:42 +08:00
parent d0309c3f3d
commit 7d7fc45ab2

View File

@ -122,28 +122,28 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--num-encoder-layers",
type=str, type=str,
default="2,4,4,4,8,4,4,4,2", default="2,4,4,8,4,4,2",
help="Number of subformer encoder layers per stack, comma separated.", help="Number of subformer encoder layers per stack, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--feedforward-dim", "--feedforward-dim",
type=str, type=str,
default="1024,1536,2048,3072,4096,3072,2048,1536,1024", default="1024,1536,2048,3072,2048,1536,1024",
help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.", help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--num-heads", "--num-heads",
type=str, type=str,
default="4,4,8,16,16,16,8,4,4", default="4,4,8,16,8,4,4",
help="Number of attention heads in the subformer encoder layers: a single int or comma-separated list.", help="Number of attention heads in the subformer encoder layers: a single int or comma-separated list.",
) )
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim",
type=str, type=str,
default="256,384,512,768,768,768,512,384,256", default="256,384,512,768,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list." help="Embedding dimension in encoder stacks: a single int or comma-separated list."
) )
@ -158,7 +158,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-structure", "--encoder-structure",
type=str, type=str,
default="S(S(S(S(S)S)S)S)S", default="S(S(S(S)S)S)S",
help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) " help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) "
"operations." "operations."
) )
@ -404,7 +404,7 @@ def get_params() -> AttributeDict:
"warm_step": 2000, "warm_step": 2000,
"env_info": get_env_info(), "env_info": get_env_info(),
"bytes_per_segment": 2048, "bytes_per_segment": 2048,
"batch_size": 15, "batch_size": 18,
"train_file_list": "train.txt", "train_file_list": "train.txt",
"valid_file_list": "valid.txt", "valid_file_list": "valid.txt",
"num_workers": 4, "num_workers": 4,