Updated decode.py to obtain WERs for subsets.

This commit is contained in:
jinzr 2023-06-28 12:44:41 +08:00
parent 11faddc830
commit f85b95e73b

View File

@ -581,39 +581,52 @@ def save_results(
enable_log = False enable_log = False
else: else:
enable_log = True enable_log = True
test_set_wers = dict() if test_set_name == "test-eval2000":
for key, results in results_dict.items(): subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"}
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" elif test_set_name == "test-rt03":
results = sorted(results) subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"}
store_transcripts(filename=recog_path, texts=results) else:
if enable_log: raise NotImplementedError(f"No implementation for testset {test_set_name}")
logging.info(f"The transcripts are stored in {recog_path}") for subset, prefix in subsets.items():
test_set_wers = dict()
# The following prints out WERs, per-word error statistics and aligned for key, results in results_dict.items():
# ref/hyp pairs. recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt"
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" results = (
with open(errs_filename, "w") as f: sorted(list(filter(lambda x: x[0].startswith(prefix), results)))
wer = write_error_stats( if subset != "avg"
f, f"{test_set_name}-{key}", results, enable_log=enable_log else sorted(results)
) )
test_set_wers[key] = wer store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
if enable_log: # The following prints out WERs, per-word error statistics and aligned
logging.info("Wrote detailed error stats to {}".format(errs_filename)) # ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{subset}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) if enable_log:
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" logging.info("Wrote detailed error stats to {}".format(errs_filename))
with open(errs_info, "w") as f:
print("settings\tWER", file=f) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}-{}, WER of different settings are:\n".format(
test_set_name, subset
)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers: for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f) s += "{}\t{}{}\n".format(key, val, note)
note = ""
s = "\nFor {}, WER of different settings are:\n".format(test_set_name) logging.info(s)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad() @torch.no_grad()