From 47519b36a42690779fd52e357e0b862ef5e3aa50 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 11:32:10 +0800 Subject: [PATCH] Use GPU for averaging checkpoints if possible. --- egs/librispeech/ASR/conformer_ctc/decode.py | 3 ++- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 3 ++- icefall/checkpoint.py | 10 +++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 5a83dd39c..343eb1d60 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -638,7 +638,8 @@ def main(): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 54c2f7a6b..4de2e126e 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -455,7 +455,8 @@ def main(): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index b8a628f4e..d6ddbf515 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -120,22 +120,26 @@ def load_checkpoint( return checkpoint -def average_checkpoints(filenames: List[Path]) -> dict: +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: """Average a list of checkpoints. Args: filenames: Filenames of the checkpoints to be averaged. We assume all checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. Returns: Return a dict (i.e., state_dict) which is the average of all model state dicts contained in the checkpoints. """ n = len(filenames) - avg = torch.load(filenames[0], map_location="cpu")["model"] + avg = torch.load(filenames[0], map_location=device)["model"] for i in range(1, n): - state_dict = torch.load(filenames[i], map_location="cpu")["model"] + state_dict = torch.load(filenames[i], map_location=device)["model"] for k in avg: avg[k] += state_dict[k]