Support computing CER and writing character-level transcripts

This commit is contained in:
marcoyang1998 2023-09-14 18:31:18 +08:00
parent 81af525de4
commit f9ef9f38eb

View File

@ -419,7 +419,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
def store_transcripts( def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]] filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None: ) -> None:
"""Save predicted results and reference transcripts to a file. """Save predicted results and reference transcripts to a file.
@ -436,6 +436,9 @@ def store_transcripts(
""" """
with open(filename, "w") as f: with open(filename, "w") as f:
for cut_id, ref, hyp in texts: for cut_id, ref, hyp in texts:
if char_level:
ref = list("".join(ref))
hyp = list("".join(hyp))
print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f)
@ -493,6 +496,7 @@ def write_error_stats(
test_set_name: str, test_set_name: str,
results: List[Tuple[str, str]], results: List[Tuple[str, str]],
enable_log: bool = True, enable_log: bool = True,
compute_CER: bool = False,
) -> float: ) -> float:
"""Write statistics based on predicted results and reference transcripts. """Write statistics based on predicted results and reference transcripts.
@ -520,7 +524,7 @@ def write_error_stats(
The reference word `SIR` is missing in the predicted The reference word `SIR` is missing in the predicted
results (a deletion error). results (a deletion error).
results: results:
An iterable of tuples. The first element is the cur_id, the second is An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result. the reference transcript and the third element is the predicted result.
enable_log: enable_log:
If True, also print detailed WER to the console. If True, also print detailed WER to the console.
@ -537,6 +541,14 @@ def write_error_stats(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0 num_corr = 0
ERR = "*" ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results: for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR) ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali: for ref_word, hyp_word in ali:
@ -612,7 +624,9 @@ def write_error_stats(
f"{cut_id}:\t" f"{cut_id}:\t"
+ " ".join( + " ".join(
( (
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali for ref_word, hyp_word in ali
) )
), ),
@ -622,7 +636,9 @@ def write_error_stats(
print("", file=f) print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
print(f"{count} {ref} -> {hyp}", file=f) print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f) print("", file=f)
@ -636,7 +652,9 @@ def write_error_stats(
print(f"{count} {hyp}", file=f) print(f"{count} {hyp}", file=f)
print("", file=f) print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
for _, word, counts in sorted( for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
): ):