diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp index 31241ef2b..a772af123 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py index dbfc2863c..7048859f9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py @@ -105,17 +105,23 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: def add_rep_arguments(parser: argparse.ArgumentParser): parser.add_argument( - "--decode-interval", + "--wandb", + type=bool, + default=False, + help="Use wandb for MLOps", + ) + parser.add_argument( + "--accum-grads", type=int, - default=200, - help="decode interval", + default=1, + help="accum-grad num.", ) parser.add_argument( - "--encoder-dim", - type=int, - default=768, - help="encoder embedding dimension", + "--multi-optim", + type=bool, + default=False, + help="use sperate optimizer (enc / dec)", ) parser.add_argument( @@ -132,41 +138,41 @@ def add_rep_arguments(parser: argparse.ArgumentParser): help="The initial learning rate. This value should not need to be changed.", ) - parser.add_argument( - "--multi-optim", - type=bool, - default=False, - help="use sperate optimizer (enc / dec)", - ) - parser.add_argument( - "--accum-grads", - type=int, - default=1, - help="accum-grad num.", - ) parser.add_argument( "--encoder-type", type=str, default='d2v', help="Type of encoder (e.g. conformer, w2v, d2v...", ) + parser.add_argument( - "--additional-block", - type=bool, - default=False, + "--encoder-dim", + type=int, + default=768, + help="encoder embedding dimension", ) + parser.add_argument( "--freeze-finetune-updates", type=int, default=0 ) + parser.add_argument( - "--wandb", + "--additional-block", type=bool, default=False, - help="Use wandb for MLOps", ) + parser.add_argument( + "--decode-interval", + type=int, + default=200, + help="decode interval", + ) + + + def add_model_arguments(parser: argparse.ArgumentParser):