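Minor fixes: add an `enable_log` flag to silence per-key logging when params.method is "attention-decoder", and save G_3_gram.pt under data/lm/.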

This commit is contained in:
Fangjun Kuang 2021-08-03 22:16:34 +08:00
parent 2be7a0a555
commit a6d9b3c9ab
3 changed files with 28 additions and 10 deletions

View File

@ -326,20 +326,31 @@ def save_results(
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]], results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
): ):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}") if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results) wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename)) if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"

View File

@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
logging.info("Loading G_3_gram.fst.txt") logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f: with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False) G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "G_3_gram.pt") torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
first_token_disambig_id = lexicon.token_table["#0"] first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"] first_word_disambig_id = lexicon.word_table["#0"]

View File

@ -225,7 +225,10 @@ def store_transcripts(
def write_error_stats( def write_error_stats(
f: TextIO, test_set_name: str, results: List[Tuple[str, str]] f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
) -> float: ) -> float:
"""Write statistics based on predicted results and reference transcripts. """Write statistics based on predicted results and reference transcripts.
@ -255,6 +258,9 @@ def write_error_stats(
results: results:
An iterable of tuples. The first element is the reference transcript An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result. while the second element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns: Returns:
Return the word error rate as a float. Return the word error rate as a float.
""" """
@ -290,11 +296,12 @@ def write_error_stats(
tot_errs = sub_errs + ins_errs + del_errs tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
logging.info( if enable_log:
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " logging.info(
f"[{tot_errs} / {ref_len}, {ins_errs} ins, " f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"{del_errs} del, {sub_errs} sub ]" f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
) f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f) print(f"%WER = {tot_err_rate}", file=f)
print( print(