diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 6b0ae0931..4f22a6f6c 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -346,8 +346,7 @@ def decode_dataset(
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts} "
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
 
     return results
@@ -430,7 +429,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,7 +462,7 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
             G = k2.Fsa.from_dict(d).to(device)
 
     if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 137fa795c..9a1aad579 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -236,7 +236,6 @@ def decode_dataset(
 
     results = []
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -264,9 +263,7 @@ def decode_dataset(
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
 
     return results
@@ -328,7 +325,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -355,7 +354,7 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
             G = k2.Fsa.from_dict(d).to(device)
 
     if params.method == "whole-lattice-rescoring":
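
The substantive change in both files is the same: HLG.pt and G_4_gram.pt are now deserialized with map_location="cpu" before the FSA is moved to the target device, so a checkpoint saved on a GPU can still be loaded on a CPU-only machine. A minimal sketch of the pattern, standalone rather than taken from the recipes (the path "HLG.pt" is a placeholder; the recipes use {params.lang_dir}/HLG.pt and data/lang_phone/HLG.pt):

import k2
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deserialize onto CPU regardless of where the tensors were saved.
# Without map_location="cpu", torch.load tries to restore tensors onto
# the CUDA device they were saved from, which fails on machines
# without that device.
HLG = k2.Fsa.from_dict(torch.load("HLG.pt", map_location="cpu"))

# Then move the FSA to whichever device is actually used for decoding.
HLG = HLG.to(device)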