Add start-batch option for RNNLM training

This commit is contained in:
Nickolay Shmyrev 2023-07-03 17:58:14 +02:00
parent c3e23ec8d2
commit d19f4bb009

View File

@ -99,6 +99,15 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
@ -242,7 +251,9 @@ def load_checkpoint_if_available(
) -> None: ) -> None:
"""Load checkpoint from file. """Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing. `params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`, Apart from loading state dict for `model`, `optimizer` and `scheduler`,
@ -261,10 +272,14 @@ def load_checkpoint_if_available(
Returns: Returns:
Return None. Return None.
""" """
if params.start_epoch <= 0:
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}") logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint( saved_params = load_checkpoint(
filename, filename,