diff --git a/icefall/utils.py b/icefall/utils.py index 0feff9dc8..3e4b19e04 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -419,7 +419,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: def store_transcripts( - filename: Pathlike, texts: Iterable[Tuple[str, str, str]] + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False ) -> None: """Save predicted results and reference transcripts to a file. @@ -436,6 +436,9 @@ def store_transcripts( """ with open(filename, "w") as f: for cut_id, ref, hyp in texts: + if char_level: + ref = list("".join(ref)) + hyp = list("".join(hyp)) print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) @@ -493,6 +496,7 @@ def write_error_stats( test_set_name: str, results: List[Tuple[str, str]], enable_log: bool = True, + compute_CER: bool = False, ) -> float: """Write statistics based on predicted results and reference transcripts. @@ -520,7 +524,7 @@ def write_error_stats( The reference word `SIR` is missing in the predicted results (a deletion error). results: - An iterable of tuples. The first element is the cur_id, the second is + An iterable of tuples. The first element is the cut_id, the second is the reference transcript and the third element is the predicted result. enable_log: If True, also print detailed WER to the console. @@ -537,6 +541,14 @@ def write_error_stats( words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) num_corr = 0 ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR) for ref_word, hyp_word in ali: @@ -612,7 +624,9 @@ def write_error_stats( f"{cut_id}:\t" + " ".join( ( - ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" for ref_word, hyp_word in ali ) ), @@ -622,7 +636,9 @@ def write_error_stats( print("", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f) - for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + for count, (ref, hyp) in sorted( + [(v, k) for k, v in subs.items()], reverse=True + ): print(f"{count} {ref} -> {hyp}", file=f) print("", file=f) @@ -636,7 +652,9 @@ def write_error_stats( print(f"{count} {hyp}", file=f) print("", file=f) - print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + print( + "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f + ) for _, word, counts in sorted( [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True ):