mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix checkpoint-writing
This commit is contained in:
parent
47d49f29d7
commit
1548cc7462
@ -376,6 +376,7 @@ def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
@ -395,6 +396,8 @@ def load_checkpoint_if_available(
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The scheduler that we are using.
|
||||
Returns:
|
||||
Return a dict containing previously saved training info.
|
||||
"""
|
||||
@ -411,6 +414,7 @@ def load_checkpoint_if_available(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
@ -784,6 +788,7 @@ def run(rank, world_size, args):
|
||||
gamma=0.5)
|
||||
|
||||
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
logging.info("Loading optimizer state dict")
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
@ -792,7 +797,6 @@ def run(rank, world_size, args):
|
||||
logging.info("Loading scheduler state dict")
|
||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
@ -881,6 +885,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
rank=rank,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user