from local

This commit is contained in:
dohe0342 2022-12-10 13:28:43 +09:00
parent ee376e5e25
commit 45ce951b8c
2 changed files with 17 additions and 16 deletions

View File

@ -103,6 +103,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
module.batch_count = batch_count
def add_rep_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=str,
default="2,4,3,2,4",
help="Number of zipformer encoder layers, comma separated.",
)
parser.add_argument(
"--feedforward-dims",
type=str,
default="1024,1024,2048,2048,1024",
help="Feedforward dimension of the zipformer encoder layers, comma separated.",
)
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
@ -124,14 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="8,8,8,8,8",
help="Number of attention heads in the zipformer encoder layers.",
)
parser.add_argument(
"--encoder-dim",
type=int,
default=768,
help="Encoder embedding dimension",
)
'''
parser.add_argument(
"--encoder-dims",
type=str,
@ -169,7 +177,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="31,31,31,31,31",
help="Sizes of kernels in convolution modules",
)
'''
parser.add_argument(
"--decoder-dim",
type=int,
@ -206,13 +214,6 @@ def get_parser():
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--wandb",
type=str2bool,
default=False,
help="Should various information be logged in wandb.",
)
parser.add_argument(
"--tensorboard",