Fix loading checkpoints saved by previous code.

Fangjun Kuang 2022-03-23 12:21:11 +08:00
parent ed2d2932ef
commit 4fe0c0dca2


@@ -392,11 +392,13 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
-        params[k] = saved_params[k]
+        params[k] = saved_params.get(k, 0)

-    params["start_epoch"] = saved_params["cur_epoch"]
+    if "cur_epoch" in saved_params:
+        params["start_epoch"] = saved_params["cur_epoch"]

     return saved_params
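
This first hunk is what the commit title refers to: checkpoints written by older code do not contain the new "cur_batch_idx" key (and may lack "cur_epoch"), so loading now falls back to a default of 0 via dict.get and only overrides "start_epoch" when "cur_epoch" is present. A minimal, self-contained sketch of that pattern; the key list is abbreviated and the dicts below are made up for illustration, not the real checkpoint contents:

def merge_saved_params(params: dict, saved_params: dict) -> dict:
    # Key list abbreviated; mirrors the keys shown in the hunk above.
    keys = [
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
        "cur_batch_idx",  # new key; absent from checkpoints written by older code
    ]
    for k in keys:
        # .get() falls back to 0, so an old checkpoint without "cur_batch_idx"
        # no longer raises KeyError here.
        params[k] = saved_params.get(k, 0)

    # Only override start_epoch when the checkpoint actually recorded cur_epoch.
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]
    return params


old_ckpt = {"batch_idx_train": 1000, "best_train_loss": 0.3,
            "best_valid_loss": 0.4, "cur_epoch": 3}
params = merge_saved_params({"start_epoch": 1}, old_ckpt)
print(params["cur_batch_idx"], params["start_epoch"])  # 0 3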
@@ -545,7 +547,6 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
-    start_batch: int = 0,
 ) -> None:
     """Train the model for one epoch.
@@ -571,8 +572,6 @@ def train_one_epoch(
       rank:
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
-      start_batch:
-        If not zero, it starts from this batch for training.
     """
     model.train()
@@ -617,7 +616,13 @@ def train_one_epoch(
         else:
             optimizer.step()

-    for batch_idx, batch in enumerate(train_dl, start=start_batch):
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
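
Resuming mid-epoch no longer relies on enumerate's start argument, which only offsets the reported index and never skips any elements of the iterable; instead the loop reads the resume point from params and fast-forwards past batches that were already consumed before the checkpoint was written. A toy sketch of that skip pattern, with a plain list standing in for the real train_dl:

train_dl = ["batch0", "batch1", "batch2", "batch3", "batch4"]  # stand-in dataloader
params = {"cur_batch_idx": 2}  # e.g. restored from a checkpoint written while batch 2 was in flight

cur_batch_idx = params.get("cur_batch_idx", 0)

for batch_idx, batch in enumerate(train_dl):
    if batch_idx < cur_batch_idx:
        # Already processed before the checkpoint was written; skip it.
        continue
    cur_batch_idx = batch_idx
    print("training on", batch)
# Prints batch2, batch3, batch4: the saved batch is re-run, earlier ones are skipped.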
@@ -646,6 +651,7 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -655,6 +661,7 @@ def train_one_epoch(
                 sampler=train_dl.sampler,
                 rank=rank,
             )
+            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
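
Together, these two hunks stash the in-epoch position on params right before the periodic checkpoint is written, so it gets serialized into that checkpoint, and delete it immediately afterwards so later checkpoints (e.g. the end-of-epoch one) do not carry a stale value. A small sketch of this stash/save/strip pattern; save_stub is a made-up stand-in for save_checkpoint_with_global_batch_idx, and a plain dict stands in for the AttributeDict used in the recipe:

import copy

def save_stub(params: dict) -> dict:
    # Pretend to write a checkpoint; return a copy as its contents.
    return copy.deepcopy(params)

params = {"batch_idx_train": 4000, "save_every_n": 2000}
batch_idx = 137  # position inside the current epoch

params["cur_batch_idx"] = batch_idx   # record where we are in this epoch
ckpt = save_stub(params)              # the mid-epoch checkpoint keeps it
del params["cur_batch_idx"]           # strip it so it does not leak into later checkpoints

assert ckpt["cur_batch_idx"] == 137
assert "cur_batch_idx" not in params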
@@ -793,10 +800,8 @@ def run(rank, world_size, args):
     if checkpoints and "sampler" in checkpoints:
         sampler_state_dict = checkpoints["sampler"]
-        start_batch = sampler_state_dict["diagnostics"]["num_kept_batches"]
     else:
         sampler_state_dict = None
-        start_batch = 0

     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
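
With the explicit skip inside the training loop, run() no longer derives a start_batch from the sampler's diagnostics; it only forwards the saved sampler state dict to the dataloader builder. A toy sketch of the restore-via-state-dict idea, where ToySampler and make_train_dl are made-up stand-ins rather than the actual lhotse sampler or librispeech.train_dataloaders API:

class ToySampler:
    def __init__(self, n_batches: int):
        self.n_batches = n_batches
        self.pos = 0  # number of batches already handed out

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, state):
        self.pos = state["pos"]

    def __iter__(self):
        while self.pos < self.n_batches:
            yield f"batch{self.pos}"
            self.pos += 1


def make_train_dl(n_batches, sampler_state_dict=None):
    sampler = ToySampler(n_batches)
    if sampler_state_dict is not None:
        sampler.load_state_dict(sampler_state_dict)  # resume where the sampler left off
    return sampler


checkpoints = {"sampler": {"pos": 3}}  # stand-in for the dict loaded from disk
if checkpoints and "sampler" in checkpoints:
    sampler_state_dict = checkpoints["sampler"]
else:
    sampler_state_dict = None

train_dl = make_train_dl(5, sampler_state_dict=sampler_state_dict)
print(list(train_dl))  # ['batch3', 'batch4'] -- no separate start_batch bookkeeping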
@@ -840,9 +845,7 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
-            start_batch=start_batch,
         )
-        start_batch = 0

         save_checkpoint(
             params=params,