from local
This commit is contained in:
parent
ee376e5e25
commit
45ce951b8c
Binary file not shown.
@ -103,6 +103,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
|||||||
module.batch_count = batch_count
|
module.batch_count = batch_count
|
||||||
|
|
||||||
|
|
||||||
|
def add_rep_arguments(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-encoder-layers",
|
||||||
|
type=str,
|
||||||
|
default="2,4,3,2,4",
|
||||||
|
help="Number of zipformer encoder layers, comma separated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--feedforward-dims",
|
||||||
|
type=str,
|
||||||
|
default="1024,1024,2048,2048,1024",
|
||||||
|
help="Feedforward dimension of the zipformer encoder layers, comma separated.",
|
||||||
|
)
|
||||||
|
|
||||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-encoder-layers",
|
"--num-encoder-layers",
|
||||||
@ -124,14 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
default="8,8,8,8,8",
|
default="8,8,8,8,8",
|
||||||
help="Number of attention heads in the zipformer encoder layers.",
|
help="Number of attention heads in the zipformer encoder layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder-dim",
|
|
||||||
type=int,
|
|
||||||
default=768,
|
|
||||||
help="Encoder embedding dimension",
|
|
||||||
)
|
|
||||||
'''
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dims",
|
"--encoder-dims",
|
||||||
type=str,
|
type=str,
|
||||||
@ -169,7 +177,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
default="31,31,31,31,31",
|
default="31,31,31,31,31",
|
||||||
help="Sizes of kernels in convolution modules",
|
help="Sizes of kernels in convolution modules",
|
||||||
)
|
)
|
||||||
'''
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-dim",
|
"--decoder-dim",
|
||||||
type=int,
|
type=int,
|
||||||
@ -206,13 +214,6 @@ def get_parser():
|
|||||||
default=12354,
|
default=12354,
|
||||||
help="Master port to use for DDP training.",
|
help="Master port to use for DDP training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--wandb",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Should various information be logged in wandb.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tensorboard",
|
"--tensorboard",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user