Add start-batch option for RNNLM training (#1161)

* Add start-batch option for RNNLM training

* Also set epoch

* Skip batches on load
This commit is contained in:
Nickolay V. Shmyrev 2023-07-04 05:13:25 +03:00 committed by GitHub
parent 9009d028a0
commit eca0202632
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -99,6 +99,15 @@ def get_parser():
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -242,7 +251,9 @@ def load_checkpoint_if_available(
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
@ -261,10 +272,14 @@ def load_checkpoint_if_available(
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
@ -283,6 +298,13 @@ def load_checkpoint_if_available(
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -438,7 +460,14 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
@ -463,6 +492,7 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -471,6 +501,7 @@ def train_one_epoch(
optimizer=optimizer,
rank=rank,
)
del params.cur_batch_idx
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"