Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 17:42:21 +00:00
Add start-batch option for RNNLM training (#1161)
* Add start-batch option for RNNLM training
* Also set epoch
* Skip batches on load
This commit is contained in:
parent 9009d028a0
commit eca0202632
@@ -99,6 +99,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
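
(For illustration: a minimal, hypothetical argparse sketch of how the two flags interact. This is not icefall's actual get_parser(), and the --start-epoch default of 1 is an assumption made here.)

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--start-epoch", type=int, default=1)  # assumed default
parser.add_argument(
    "--start-batch",
    type=int,
    default=0,
    help="""If positive, --start-epoch is ignored and
    it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
    """,
)

args = parser.parse_args(["--start-batch", "4000"])
if args.start_batch > 0:
    print(f"resume from checkpoint-{args.start_batch}.pt")  # this branch wins
elif args.start_epoch > 1:
    print(f"resume from epoch-{args.start_epoch - 1}.pt")
else:
    print("start from scratch")
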
@@ -242,7 +251,9 @@ def load_checkpoint_if_available(
 ) -> None:
     """Load checkpoint from file.
 
-    If params.start_epoch is positive, it will load the checkpoint from
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
     `params.start_epoch - 1`. Otherwise, this function does nothing.
 
     Apart from loading state dict for `model`, `optimizer` and `scheduler`,
@@ -261,10 +272,14 @@ def load_checkpoint_if_available(
     Returns:
       Return None.
     """
-    if params.start_epoch <= 0:
-        return
-
-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
     logging.info(f"Loading checkpoint: {filename}")
     saved_params = load_checkpoint(
         filename,
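
(The selection above reads as a small pure function. The sketch below is an illustrative reading only; resume_filename and its parameters are hypothetical names mirroring the fields on params, not anything defined in icefall.)

from pathlib import Path
from typing import Optional

def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    if start_batch > 0:
        # --start-batch wins: resume from a mid-epoch batch checkpoint.
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        # Otherwise resume from the checkpoint written at the end of the
        # previous epoch.
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh start: nothing to load

assert resume_filename(Path("exp"), 4000, 3) == Path("exp/checkpoint-4000.pt")
assert resume_filename(Path("exp"), 0, 3) == Path("exp/epoch-2.pt")
assert resume_filename(Path("exp"), 0, 1) is None
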
@@ -283,6 +298,13 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]
 
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
     return saved_params
 
 
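
(What the restore step buys: when resuming from a batch checkpoint, both the epoch counter and the position within that epoch come from the saved state. Toy run below; saved_params is a stand-in dict with assumed values, not a real checkpoint.)

params = {"start_batch": 4000, "start_epoch": 1}
saved_params = {"cur_epoch": 3, "cur_batch_idx": 1500}  # stand-in values

if params["start_batch"] > 0:
    # Resume inside the epoch the checkpoint was written in ...
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]
    # ... and remember how far into that epoch training had progressed.
    if "cur_batch_idx" in saved_params:
        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

print(params)  # {'start_batch': 4000, 'start_epoch': 3, 'cur_batch_idx': 1500}
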
@@ -438,7 +460,14 @@ def train_one_epoch(
 
     tot_loss = MetricsTracker()
 
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
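
(The fast-forward above assumes the dataloader yields batches in the same order across runs; batches before the saved position are still drawn from train_dl, just not trained on. Toy version below, with range(6) standing in for train_dl.)

cur_batch_idx = 3  # as restored into params["cur_batch_idx"] from the checkpoint
trained_on = []
for batch_idx, batch in enumerate(range(6)):
    if batch_idx < cur_batch_idx:
        continue  # already seen before the checkpoint: draw and discard
    cur_batch_idx = batch_idx
    trained_on.append(batch)

print(trained_on)  # [3, 4, 5]
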
@@ -463,6 +492,7 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -471,6 +501,7 @@ def train_one_epoch(
                 optimizer=optimizer,
                 rank=rank,
             )
+            del params.cur_batch_idx
 
         if batch_idx % params.log_interval == 0:
             # Note: "frames" here means "num_tokens"
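
(Why cur_batch_idx is set just before saving and deleted right after: it is attached to params only so the periodic batch checkpoint records the position within the epoch, then removed so it does not linger for the rest of training. Toy sketch below; Params loosely mimics an attribute-style dict, and save_checkpoint is a stand-in for save_checkpoint_with_global_batch_idx, not its real signature.)

class Params(dict):
    # Minimal attribute-style dict, for illustration only.
    def __getattr__(self, k): return self[k]
    def __setattr__(self, k, v): self[k] = v
    def __delattr__(self, k): del self[k]

def save_checkpoint(params):
    # Stand-in: the real helper serializes params (among other state) to
    # checkpoint-{global_batch_idx}.pt, cur_batch_idx included.
    print("saved with cur_batch_idx =", params.get("cur_batch_idx"))

params = Params(batch_idx_train=4000)
batch_idx = 1500

params.cur_batch_idx = batch_idx  # recorded in the checkpoint
save_checkpoint(params)           # saved with cur_batch_idx = 1500
del params.cur_batch_idx          # transient: removed once the save is done
print("cur_batch_idx" in params)  # False
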