From a87a39da8c6769237872181ee4d75b4cd021dc9b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 19 Aug 2021 14:52:01 +0800 Subject: [PATCH] Fix an error in displaying decoding process. --- egs/librispeech/ASR/conformer_ctc/decode.py | 11 +++++------ egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 11 +++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 889a0a474..0722cd582 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -284,7 +284,6 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -314,9 +313,7 @@ def decode_dataset( if batch_idx % 100 == 0: logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_idx}, cuts processed until now is {num_cuts}" ) return results @@ -399,7 +396,9 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -430,7 +429,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") G = k2.Fsa.from_dict(d).to(device) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 137fa795c..9a1aad579 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -236,7 +236,6 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -264,9 +263,7 @@ def decode_dataset( if batch_idx % 100 == 0: logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_idx}, cuts processed until now is {num_cuts}" ) return results @@ -328,7 +325,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt")) + HLG = k2.Fsa.from_dict( + torch.load("data/lang_phone/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -355,7 +354,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring":