from local

This commit is contained in:
dohe0342 2022-12-10 13:31:13 +09:00
parent 45ce951b8c
commit de46c7b882
2 changed files with 30 additions and 8 deletions

View File

@ -105,18 +105,40 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
def add_rep_arguments(parser: argparse.ArgumentParser): def add_rep_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--decode-interval",
type=str, type=int,
default="2,4,3,2,4", default=200,
help="Number of zipformer encoder layers, comma separated.", help="decode interval",
) )
parser.add_argument( parser.add_argument(
"--feedforward-dims", "--encoder-dim",
type=str, type=int,
default="1024,1024,2048,2048,1024", default=768,
help="Feedforward dimension of the zipformer encoder layers, comma separated.", help="encoder embedding dimension",
) )
parser.add_argument(
"--peak-enc-lr",
type=float,
default=0.0001,
help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
"--peak-dec-lr",
type=float,
default=0.001,
help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
"--multi-optim",
type=bool,
default=False,
help="use sperate optimizer (enc / dec)",
)
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(