mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Minor fixes.
This commit is contained in:
parent
2be7a0a555
commit
a6d9b3c9ab
@ -326,20 +326,31 @@ def save_results(
|
|||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||||
):
|
):
|
||||||
|
if params.method == "attention-decoder":
|
||||||
|
# Set it to False since there are too many logs.
|
||||||
|
enable_log = False
|
||||||
|
else:
|
||||||
|
enable_log = True
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
if enable_log:
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||||
|
)
|
||||||
test_set_wers[key] = wer
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
if enable_log:
|
||||||
|
logging.info(
|
||||||
|
"Wrote detailed error stats to {}".format(errs_filename)
|
||||||
|
)
|
||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||||
|
|||||||
@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
|||||||
logging.info("Loading G_3_gram.fst.txt")
|
logging.info("Loading G_3_gram.fst.txt")
|
||||||
with open("data/lm/G_3_gram.fst.txt") as f:
|
with open("data/lm/G_3_gram.fst.txt") as f:
|
||||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||||
torch.save(G.as_dict(), "G_3_gram.pt")
|
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
|
||||||
|
|
||||||
first_token_disambig_id = lexicon.token_table["#0"]
|
first_token_disambig_id = lexicon.token_table["#0"]
|
||||||
first_word_disambig_id = lexicon.word_table["#0"]
|
first_word_disambig_id = lexicon.word_table["#0"]
|
||||||
|
|||||||
@ -225,7 +225,10 @@ def store_transcripts(
|
|||||||
|
|
||||||
|
|
||||||
def write_error_stats(
|
def write_error_stats(
|
||||||
f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
|
f: TextIO,
|
||||||
|
test_set_name: str,
|
||||||
|
results: List[Tuple[str, str]],
|
||||||
|
enable_log: bool = True,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Write statistics based on predicted results and reference transcripts.
|
"""Write statistics based on predicted results and reference transcripts.
|
||||||
|
|
||||||
@ -255,6 +258,9 @@ def write_error_stats(
|
|||||||
results:
|
results:
|
||||||
An iterable of tuples. The first element is the reference transcript
|
An iterable of tuples. The first element is the reference transcript
|
||||||
while the second element is the predicted result.
|
while the second element is the predicted result.
|
||||||
|
enable_log:
|
||||||
|
If True, also print detailed WER to the console.
|
||||||
|
Otherwise, it is written only to the given file.
|
||||||
Returns:
|
Returns:
|
||||||
Return None.
|
Return None.
|
||||||
"""
|
"""
|
||||||
@ -290,11 +296,12 @@ def write_error_stats(
|
|||||||
tot_errs = sub_errs + ins_errs + del_errs
|
tot_errs = sub_errs + ins_errs + del_errs
|
||||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||||
|
|
||||||
logging.info(
|
if enable_log:
|
||||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
logging.info(
|
||||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||||
f"{del_errs} del, {sub_errs} sub ]"
|
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||||
)
|
f"{del_errs} del, {sub_errs} sub ]"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"%WER = {tot_err_rate}", file=f)
|
print(f"%WER = {tot_err_rate}", file=f)
|
||||||
print(
|
print(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user