diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 0ace3e9a1..d13a7bd41 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index fe422eb16..c72936741 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -103,6 +103,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         module.batch_count = batch_count
 
 
+def add_rep_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
@@ -124,14 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default="8,8,8,8,8",
         help="Number of attention heads in the zipformer encoder layers.",
     )
-
-    parser.add_argument(
-        "--encoder-dim",
-        type=int,
-        default=768,
-        help="Encoder embedding dimension",
-    )
-    '''
+
     parser.add_argument(
         "--encoder-dims",
         type=str,
@@ -169,7 +177,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default="31,31,31,31,31",
         help="Sizes of kernels in convolution modules",
     )
-    '''
+
     parser.add_argument(
         "--decoder-dim",
         type=int,
@@ -206,13 +214,6 @@ def get_parser():
         default=12354,
         help="Master port to use for DDP training.",
    )
-
-    parser.add_argument(
-        "--wandb",
-        type=str2bool,
-        default=False,
-        help="Should various information be logged in wandb.",
-    )
 
     parser.add_argument(
         "--tensorboard",
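
Review note: the new add_rep_arguments follows the comma-separated-string convention already used by add_model_arguments, so downstream code is expected to split the values itself. Below is a minimal sketch of that consumption pattern; the parse_int_list helper and the standalone parser are illustrative only, not code from this patch. Note also that add_model_arguments still registers its own --num-encoder-layers, so calling both helpers on the same argparse.ArgumentParser would raise argparse.ArgumentError for the duplicate option string; presumably the two helpers are meant to be attached to different parsers.

    # Illustrative only -- not part of the patch. Shows how the comma-separated
    # defaults added by add_rep_arguments are typically turned into per-stack
    # integer lists. parse_int_list is a hypothetical helper name.
    import argparse
    from typing import List

    def parse_int_list(s: str) -> List[int]:
        # "2,4,3,2,4" -> [2, 4, 3, 2, 4]
        return [int(x) for x in s.split(",")]

    parser = argparse.ArgumentParser()
    # Mirror the two options this patch introduces:
    parser.add_argument("--num-encoder-layers", type=str, default="2,4,3,2,4")
    parser.add_argument("--feedforward-dims", type=str, default="1024,1024,2048,2048,1024")

    args = parser.parse_args([])  # use the defaults
    num_encoder_layers = parse_int_list(args.num_encoder_layers)  # [2, 4, 3, 2, 4]
    feedforward_dims = parse_int_list(args.feedforward_dims)  # [1024, 1024, 2048, 2048, 1024]

    # Both lists describe the same zipformer stacks, so their lengths must agree.
    assert len(num_encoder_layers) == len(feedforward_dims)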