Fix loading checkpoints saved by previous code.

This commit is contained in:
Fangjun Kuang 2022-03-23 12:21:11 +08:00
parent ed2d2932ef
commit 4fe0c0dca2

View File

@ -392,11 +392,13 @@ def load_checkpoint_if_available(
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
"cur_batch_idx",
]
for k in keys:
params[k] = saved_params[k]
params[k] = saved_params.get(k, 0)
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
return saved_params
@ -545,7 +547,6 @@ def train_one_epoch(
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
start_batch: int = 0,
) -> None:
"""Train the model for one epoch.
@ -571,8 +572,6 @@ def train_one_epoch(
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
start_batch:
If not zero, it starts from this batch for training.
"""
model.train()
@ -617,7 +616,13 @@ def train_one_epoch(
else:
optimizer.step()
for batch_idx, batch in enumerate(train_dl, start=start_batch):
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -646,6 +651,7 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -655,6 +661,7 @@ def train_one_epoch(
sampler=train_dl.sampler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -793,10 +800,8 @@ def run(rank, world_size, args):
if checkpoints and "sampler" in checkpoints:
sampler_state_dict = checkpoints["sampler"]
start_batch = sampler_state_dict["diagnostics"]["num_kept_batches"]
else:
sampler_state_dict = None
start_batch = 0
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
@ -840,9 +845,7 @@ def run(rank, world_size, args):
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
start_batch=start_batch,
)
start_batch = 0
save_checkpoint(
params=params,