diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 9ce4f69da..d94a2f725 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -153,8 +153,8 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.
 
@@ -198,6 +198,8 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]
 
+    return saved_params
+
 
 def save_checkpoint(
     params: AttributeDict,
@@ -485,6 +487,9 @@ def run(rank, world_size, args):
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
     )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
@@ -496,9 +501,8 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
 
-    load_checkpoint_if_available(
-        params=params, model=model, optimizer=optimizer
-    )
+    optimizer.load_state_dict(checkpoints["optimizer"])
+    scheduler.load_state_dict(checkpoints["scheduler"])
 
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 3dc1d9436..e45df4fe4 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -94,7 +94,7 @@ def load_checkpoint(
         s = checkpoint[name]
         if obj and s:
             obj.load_state_dict(s)
-        checkpoint.pop(name)
+            checkpoint.pop(name)
 
     load("optimizer", optimizer)
     load("scheduler", scheduler)
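
Note on the patched flow: load_checkpoint_if_available() now returns the saved_params dict so the caller can restore optimizer and scheduler state after those objects are constructed. If no checkpoint exists yet (a fresh run), the function presumably returns None, so code that indexes the result may want a guard. Below is a minimal, self-contained sketch of that restore pattern; the toy model, the hyperparameters, and the if-checkpoints guard are illustrative assumptions, not part of the patch above.

    import torch.nn as nn
    import torch.optim as optim
    from torch.optim.lr_scheduler import StepLR

    model = nn.Linear(4, 2)  # stand-in for the acoustic model built in run()

    # Stand-in for the dict returned by load_checkpoint_if_available();
    # it would be None on a fresh run where no checkpoint has been saved yet.
    checkpoints = None

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

    # Restore training state only when a checkpoint was actually loaded.
    if checkpoints is not None:
        optimizer.load_state_dict(checkpoints["optimizer"])
        scheduler.load_state_dict(checkpoints["scheduler"])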