From a6d9b3c9ab625154fe37ed26ca468a14751cecc7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 3 Aug 2021 22:16:34 +0800 Subject: [PATCH] Minor fixes. --- egs/librispeech/ASR/conformer_ctc/decode.py | 17 ++++++++++++++--- egs/librispeech/ASR/local/compile_hlg.py | 2 +- icefall/utils.py | 19 +++++++++++++------ 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 0611814f6..889a0a474 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -326,20 +326,31 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index c02fb7c0d..b30402161 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Loading G_3_gram.fst.txt") with open("data/lm/G_3_gram.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "G_3_gram.pt") + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] diff --git a/icefall/utils.py b/icefall/utils.py index 1f2cf95f3..3d48badfe 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -225,7 +225,10 @@ def store_transcripts( def write_error_stats( - f: TextIO, test_set_name: str, results: List[Tuple[str, str]] + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, ) -> float: """Write statistics based on predicted results and reference transcripts. @@ -255,6 +258,9 @@ def write_error_stats( results: An iterable of tuples. The first element is the reference transcript while the second element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. Returns: Return None. """ @@ -290,11 +296,12 @@ def write_error_stats( tot_errs = sub_errs + ins_errs + del_errs tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) - logging.info( - f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " - f"[{tot_errs} / {ref_len}, {ins_errs} ins, " - f"{del_errs} del, {sub_errs} sub ]" - ) + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) print(f"%WER = {tot_err_rate}", file=f) print(