From d19f4bb0098266a655fb4de4342df5e04568af5e Mon Sep 17 00:00:00 2001 From: Nickolay Shmyrev Date: Mon, 3 Jul 2023 17:58:14 +0200 Subject: [PATCH] Add start-batch option for RNNLM training --- icefall/rnn_lm/train.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 0f0887859..6ff4f16ec 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -99,6 +99,15 @@ def get_parser(): """, ) + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -242,7 +251,9 @@ def load_checkpoint_if_available( ) -> None: """Load checkpoint from file. - If params.start_epoch is positive, it will load the checkpoint from + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Otherwise, this function does nothing. Apart from loading state dict for `model`, `optimizer` and `scheduler`, @@ -261,10 +272,14 @@ def load_checkpoint_if_available( Returns: Return None. """ - if params.start_epoch <= 0: + + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: return - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" logging.info(f"Loading checkpoint: {filename}") saved_params = load_checkpoint( filename,