Update decode.py

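Resolve issues with abnormal output formats and inaccurate error rates in `save_results`: pass `char_level=True` to `store_transcripts` and `compute_CER=True` to `write_error_stats`.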
This commit is contained in:
fenghaojin 2024-03-14 17:00:16 +08:00 committed by GitHub
parent f28c05f4f5
commit d5309db7f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -492,7 +492,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -500,7 +500,7 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True
) )
test_set_wers[key] = wer test_set_wers[key] = wer