mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
support computing CER, writing character level transcript
This commit is contained in:
parent
81af525de4
commit
f9ef9f38eb
@ -419,7 +419,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
|
|||||||
|
|
||||||
|
|
||||||
def store_transcripts(
|
def store_transcripts(
|
||||||
filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
|
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save predicted results and reference transcripts to a file.
|
"""Save predicted results and reference transcripts to a file.
|
||||||
|
|
||||||
@ -436,6 +436,9 @@ def store_transcripts(
|
|||||||
"""
|
"""
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
for cut_id, ref, hyp in texts:
|
for cut_id, ref, hyp in texts:
|
||||||
|
if char_level:
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
print(f"{cut_id}:\tref={ref}", file=f)
|
print(f"{cut_id}:\tref={ref}", file=f)
|
||||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||||
|
|
||||||
@ -493,6 +496,7 @@ def write_error_stats(
|
|||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results: List[Tuple[str, str]],
|
results: List[Tuple[str, str]],
|
||||||
enable_log: bool = True,
|
enable_log: bool = True,
|
||||||
|
compute_CER: bool = False,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Write statistics based on predicted results and reference transcripts.
|
"""Write statistics based on predicted results and reference transcripts.
|
||||||
|
|
||||||
@ -520,7 +524,7 @@ def write_error_stats(
|
|||||||
The reference word `SIR` is missing in the predicted
|
The reference word `SIR` is missing in the predicted
|
||||||
results (a deletion error).
|
results (a deletion error).
|
||||||
results:
|
results:
|
||||||
An iterable of tuples. The first element is the cur_id, the second is
|
An iterable of tuples. The first element is the cut_id, the second is
|
||||||
the reference transcript and the third element is the predicted result.
|
the reference transcript and the third element is the predicted result.
|
||||||
enable_log:
|
enable_log:
|
||||||
If True, also print detailed WER to the console.
|
If True, also print detailed WER to the console.
|
||||||
@ -537,6 +541,14 @@ def write_error_stats(
|
|||||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||||
num_corr = 0
|
num_corr = 0
|
||||||
ERR = "*"
|
ERR = "*"
|
||||||
|
|
||||||
|
if compute_CER:
|
||||||
|
for i, res in enumerate(results):
|
||||||
|
cut_id, ref, hyp = res
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
results[i] = (cut_id, ref, hyp)
|
||||||
|
|
||||||
for cut_id, ref, hyp in results:
|
for cut_id, ref, hyp in results:
|
||||||
ali = kaldialign.align(ref, hyp, ERR)
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
for ref_word, hyp_word in ali:
|
for ref_word, hyp_word in ali:
|
||||||
@ -612,7 +624,9 @@ def write_error_stats(
|
|||||||
f"{cut_id}:\t"
|
f"{cut_id}:\t"
|
||||||
+ " ".join(
|
+ " ".join(
|
||||||
(
|
(
|
||||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
ref_word
|
||||||
|
if ref_word == hyp_word
|
||||||
|
else f"({ref_word}->{hyp_word})"
|
||||||
for ref_word, hyp_word in ali
|
for ref_word, hyp_word in ali
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@ -622,7 +636,9 @@ def write_error_stats(
|
|||||||
print("", file=f)
|
print("", file=f)
|
||||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||||
|
|
||||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
for count, (ref, hyp) in sorted(
|
||||||
|
[(v, k) for k, v in subs.items()], reverse=True
|
||||||
|
):
|
||||||
print(f"{count} {ref} -> {hyp}", file=f)
|
print(f"{count} {ref} -> {hyp}", file=f)
|
||||||
|
|
||||||
print("", file=f)
|
print("", file=f)
|
||||||
@ -636,7 +652,9 @@ def write_error_stats(
|
|||||||
print(f"{count} {hyp}", file=f)
|
print(f"{count} {hyp}", file=f)
|
||||||
|
|
||||||
print("", file=f)
|
print("", file=f)
|
||||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
print(
|
||||||
|
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
|
||||||
|
)
|
||||||
for _, word, counts in sorted(
|
for _, word, counts in sorted(
|
||||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||||
):
|
):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user