Fix checkpoint-writing

Daniel Povey 2022-04-05 11:19:40 +08:00
parent 47d49f29d7
commit 1548cc7462


@@ -376,6 +376,7 @@ def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.
@@ -395,6 +396,8 @@ def load_checkpoint_if_available(
        The training model.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info.
    """
@@ -411,6 +414,7 @@ def load_checkpoint_if_available(
        filename,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
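For orientation, here is a minimal sketch of what the underlying load_checkpoint helper might do with the new scheduler argument. This is illustrative only; the function name, signature, and dict layout are assumptions, not the actual icefall checkpoint.py implementation.

# Hypothetical loader sketch; names and checkpoint layout are assumptions.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def load_checkpoint(
    filename: str,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Dict[str, Any]:
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    # Restore optimizer/scheduler state only when both the object and the
    # saved entry are present.
    if optimizer is not None and checkpoint.get("optimizer") is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    if scheduler is not None and checkpoint.get("scheduler") is not None:
        scheduler.load_state_dict(checkpoint["scheduler"])
    return checkpoint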
@@ -784,6 +788,7 @@ def run(rank, world_size, args):
                        gamma=0.5)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])
@@ -792,7 +797,6 @@ def run(rank, world_size, args):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2 ** 22
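The restore logic in run() follows the usual guarded pattern: attempt each load_state_dict only when a saved entry exists. A self-contained sketch of that pattern, where StepLR, the toy model, and the dict layout are assumptions chosen for illustration:

import logging

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# `checkpoints` is whatever load_checkpoint_if_available() returned;
# it is None (or missing keys) on a fresh run, so each restore is guarded.
checkpoints = None

if checkpoints and "optimizer" in checkpoints:
    logging.info("Loading optimizer state dict")
    optimizer.load_state_dict(checkpoints["optimizer"])

if checkpoints and checkpoints.get("scheduler") is not None:
    logging.info("Loading scheduler state dict")
    scheduler.load_state_dict(checkpoints["scheduler"])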
@@ -881,6 +885,7 @@ def run(rank, world_size, args):
            params=params,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            rank=rank,
        )
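On the save side, the fix is to serialize the scheduler state next to the model and optimizer so the learning-rate schedule resumes where it stopped instead of restarting from scratch. A hedged sketch of what such a save_checkpoint could look like; the actual icefall helper takes more arguments (e.g. the sampler) and may use a different layout:

# Hypothetical saver sketch; signature and layout are assumptions.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def save_checkpoint(
    filename: str,
    model: nn.Module,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
    # The "scheduler" entry is what this commit threads through; without it,
    # a resumed run would begin with a freshly initialized LR schedule.
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
    }
    if params is not None:
        checkpoint["params"] = params
    torch.save(checkpoint, filename)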