From f85b95e73b31b6ef000010ed70373244484338e7 Mon Sep 17 00:00:00 2001
From: jinzr <60612200+JinZr@users.noreply.github.com>
Date: Wed, 28 Jun 2023 12:44:41 +0800
Subject: [PATCH] Updated decode.py to obtain WERs for subsets.

---
 egs/swbd/ASR/conformer_ctc/decode.py | 71 ++++++++++++++++------------
 1 file changed, 42 insertions(+), 29 deletions(-)

diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py
index f89ad17e0..05d07f44f 100755
--- a/egs/swbd/ASR/conformer_ctc/decode.py
+++ b/egs/swbd/ASR/conformer_ctc/decode.py
@@ -581,39 +581,52 @@ def save_results(
         enable_log = False
     else:
         enable_log = True
-    test_set_wers = dict()
-    for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
-
-        # The following prints out WERs, per-word error statistics and aligned
-        # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+    if test_set_name == "test-eval2000":
+        subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"}
+    elif test_set_name == "test-rt03":
+        subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"}
+    else:
+        raise NotImplementedError(f"No implementation for testset {test_set_name}")
+    for subset, prefix in subsets.items():
+        test_set_wers = dict()
+        for key, results in results_dict.items():
+            recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt"
+            results = (
+                sorted(list(filter(lambda x: x[0].startswith(prefix), results)))
+                if subset != "avg"
+                else sorted(results)
             )
-            test_set_wers[key] = wer
+            store_transcripts(filename=recog_path, texts=results)
+            if enable_log:
+                logging.info(f"The transcripts are stored in {recog_path}")
 
-        if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            # The following prints out WERs, per-word error statistics and aligned
+            # ref/hyp pairs.
+            errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt"
+            with open(errs_filename, "w") as f:
+                wer = write_error_stats(
+                    f, f"{test_set_name}-{subset}-{key}", results, enable_log=enable_log
+                )
+                test_set_wers[key] = wer
 
-    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
+            if enable_log:
+                logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+        errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tWER", file=f)
+            for key, val in test_set_wers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}-{}, WER of different settings are:\n".format(
+            test_set_name, subset
+        )
+        note = "\tbest for {}".format(test_set_name)
         for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
-    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
-        note = ""
-    logging.info(s)
+            s += "{}\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)
 
 
 @torch.no_grad()
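
Note (not part of the patch): the subset split relies only on utterance-ID
prefixes. eval2000 mixes Callhome ("en_") and Switchboard ("sw_") utterances,
rt03 mixes Fisher ("fsh_") and Switchboard, and the "avg" entry keeps
everything for the overall WER. Below is a minimal, self-contained sketch of
that filtering logic with made-up utterance IDs; in the recipe, each entry of
`results` is a (cut_id, ref_words, hyp_words) tuple taken from results_dict.

    # Sketch only: mirrors the prefix filtering inside save_results().
    # The IDs below are invented for illustration.
    results = [
        ("en_4156-A_0001", "hi there".split(), "hi there".split()),
        ("sw_4390-B_0002", "okay".split(), "okay".split()),
        ("en_4852-B_0003", "yeah right".split(), "yeah".split()),
    ]

    subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"}

    for subset, prefix in subsets.items():
        kept = (
            sorted(filter(lambda x: x[0].startswith(prefix), results))
            if subset != "avg"
            else sorted(results)
        )
        print(subset, [cut_id for cut_id, _, _ in kept])

    # Output:
    # callhome ['en_4156-A_0001', 'en_4852-B_0003']
    # swbd ['sw_4390-B_0002']
    # avg ['en_4156-A_0001', 'en_4852-B_0003', 'sw_4390-B_0002']

Since sorted() accepts any iterable, the list() wrapper around filter() in the
patch is redundant but harmless; the sketch omits it.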