diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3b1d34757..c5ae3ad7d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -640,7 +640,8 @@ def main(): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index d9d019743..f44164f44 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -457,7 +457,8 @@ def main(): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index b8a628f4e..d6ddbf515 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -120,22 +120,26 @@ def load_checkpoint( return checkpoint -def average_checkpoints(filenames: List[Path]) -> dict: +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: """Average a list of checkpoints. Args: filenames: Filenames of the checkpoints to be averaged. We assume all checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. Returns: Return a dict (i.e., state_dict) which is the average of all model state dicts contained in the checkpoints. """ n = len(filenames) - avg = torch.load(filenames[0], map_location="cpu")["model"] + avg = torch.load(filenames[0], map_location=device)["model"] for i in range(1, n): - state_dict = torch.load(filenames[i], map_location="cpu")["model"] + state_dict = torch.load(filenames[i], map_location=device)["model"] for k in avg: avg[k] += state_dict[k]