Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-08-03 22:16:34 +08:00
parent 2be7a0a555
commit a6d9b3c9ab
3 changed files with 28 additions and 10 deletions

View File

@ -326,20 +326,31 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"

View File

@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "G_3_gram.pt")
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]

View File

@ -225,7 +225,10 @@ def store_transcripts(
def write_error_stats(
f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
@ -255,6 +258,9 @@ def write_error_stats(
results:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
@ -290,11 +296,12 @@ def write_error_stats(
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(