From 1548cc7462a59da00f3bddad7b51166c5a0a3b09 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 5 Apr 2022 11:19:40 +0800
Subject: [PATCH] Fix checkpoint-writing

---
 egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 83558a72b..c63c849c4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -376,6 +376,7 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
 
@@ -395,6 +396,8 @@ def load_checkpoint_if_available(
         The training model.
       optimizer:
         The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
     Returns:
       Return a dict containing previously saved training info.
     """
@@ -411,6 +414,7 @@ def load_checkpoint_if_available(
         filename,
         model=model,
         optimizer=optimizer,
+        scheduler=scheduler,
     )
 
     keys = [
@@ -784,6 +788,7 @@ def run(rank, world_size, args):
                                                gamma=0.5)
 
 
+
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])
@@ -792,7 +797,6 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])
 
-
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
             2 ** 22
         )
@@ -881,6 +885,7 @@ def run(rank, world_size, args):
            params=params,
            model=model,
            optimizer=optimizer,
+           scheduler=scheduler,
            sampler=train_dl.sampler,
            rank=rank,
        )