Increase batch size

This commit is contained in:
Daniel Povey 2023-05-16 12:13:13 +08:00
parent 8001a46758
commit 465d41c429

View File

@ -128,7 +128,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--feedforward-dim", "--feedforward-dim",
type=str, type=str,
default="512,768,1024,1536,1024,768,512", default="768,1024,1536,2048,1536,1024,768",
help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.", help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.",
) )
@ -142,7 +142,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim",
type=str, type=str,
default="256,256,384,512,384,256,256", default="256,384,512,768,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list." help="Embedding dimension in encoder stacks: a single int or comma-separated list."
) )
@ -156,7 +156,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--value-head-dim", "--value-head-dim",
type=str, type=str,
default="12", default="16",
help="Value dimension per head in encoder stacks: a single int or comma-separated list." help="Value dimension per head in encoder stacks: a single int or comma-separated list."
) )
@ -437,7 +437,7 @@ def get_params() -> AttributeDict:
"warm_step": 2000, "warm_step": 2000,
"env_info": get_env_info(), "env_info": get_env_info(),
"bytes_per_segment": 2048, "bytes_per_segment": 2048,
"batch_size": 18, "batch_size": 16,
"train_file_list": "train.txt", "train_file_list": "train.txt",
"valid_file_list": "valid.txt", "valid_file_list": "valid.txt",
"num_workers": 4, "num_workers": 4,