From d3101fb0053f4cdba87bf96c06a330cb142a4a14 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 26 Jul 2021 08:08:14 +0800
Subject: [PATCH] Fix loading checkpoint in DDP training.

---
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 14 +++++++++-----
 icefall/checkpoint.py                      |  2 +-
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 9ce4f69da..d94a2f725 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -153,8 +153,8 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.

@@ -198,6 +198,8 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]

+    return saved_params
+

 def save_checkpoint(
     params: AttributeDict,
@@ -485,6 +487,9 @@ def run(rank, world_size, args):
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
     )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
@@ -496,9 +501,8 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

-    load_checkpoint_if_available(
-        params=params, model=model, optimizer=optimizer
-    )
+    optimizer.load_state_dict(checkpoints["optimizer"])
+    scheduler.load_state_dict(checkpoints["scheduler"])

     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 3dc1d9436..e45df4fe4 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -94,7 +94,7 @@ def load_checkpoint(
         s = checkpoint[name]
         if obj and s:
             obj.load_state_dict(s)
-        checkpoint.pop(name)
+            checkpoint.pop(name)

     load("optimizer", optimizer)
     load("scheduler", scheduler)
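
Note on the fix: the ordering is the whole point. Model weights must be restored
before the model is wrapped in DDP (after wrapping, state_dict keys gain a
"module." prefix and no longer match a plain checkpoint), while optimizer and
scheduler state can only be restored after those objects are constructed. The
icefall/checkpoint.py hunk makes the second half possible: entries that were
never handed to load_state_dict now stay in the returned dict instead of being
popped unconditionally. Below is a minimal, self-contained sketch of that
ordering, not icefall code; the nn.Linear model, SGD optimizer, and
"epoch-0.pt" file name are hypothetical stand-ins.

    #!/usr/bin/env python3
    # Sketch of the resume order this patch establishes (hypothetical
    # model/optimizer; only the ordering mirrors the patched train.py).

    import torch
    import torch.nn as nn
    from torch.optim.lr_scheduler import StepLR

    def make_parts():
        model = nn.Linear(10, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
        return model, optimizer, scheduler

    # Pretend a previous run saved a checkpoint.
    model, optimizer, scheduler = make_parts()
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        "epoch-0.pt",
    )

    # --- Resuming, in the order the patch enforces ---
    model, optimizer, scheduler = make_parts()

    # 1) Load model weights BEFORE wrapping in DDP; a DDP-wrapped model's
    #    state_dict keys are prefixed with "module." and would not match.
    checkpoints = torch.load("epoch-0.pt", map_location="cpu")
    model.load_state_dict(checkpoints["model"])
    # model = DDP(model.to(device), device_ids=[rank])  # would happen here

    # 2) Load optimizer/scheduler state AFTER constructing them; the
    #    entries survive in `checkpoints` because the checkpoint.py hunk
    #    no longer pops keys that were not loaded into an object.
    optimizer.load_state_dict(checkpoints["optimizer"])
    scheduler.load_state_dict(checkpoints["scheduler"])
    print("resumed at lr:", scheduler.get_last_lr())

One caveat when adapting this: on a fresh run with no checkpoint on disk,
load_checkpoint_if_available presumably returns None, so real code would guard
the two load_state_dict calls with an `if checkpoints:` check.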