From cca31138177778734243013a663c475180582e75 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 7 Feb 2023 21:16:56 +0800 Subject: [PATCH] calculate symbol delay for (start, end) timestamps --- egs/librispeech/ASR/conformer_ctc3/decode.py | 19 ++++-- icefall/utils.py | 70 ++++++++++++++++---- 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 33d04650f..3b24ad597 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -633,7 +633,11 @@ def save_results( ) with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( - f, f"{test_set_name}-{key}", results, enable_log=True + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + with_end_time=True, ) test_set_wers[key] = wer test_set_delays[key] = (mean_delay, var_delay) @@ -649,16 +653,17 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + # sort according to the mean start symbol delay + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( params.res_dir / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(delays_info, "w") as f: - print("settings\tsymbol-delay", file=f) + print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) for key, val in test_set_delays: print( - "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + "{}\tmean: {}, variance: {}".format(key, val[0], val[1]), file=f, ) @@ -669,10 +674,12 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in 
test_set_delays: - s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note) note = "" logging.info(s) diff --git a/icefall/utils.py b/icefall/utils.py index c89356c58..729c15ed9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -646,9 +646,18 @@ def write_error_stats( def write_error_stats_with_timestamps( f: TextIO, test_set_name: str, - results: List[Tuple[str, List[str], List[str], List[float], List[float]]], + results: List[ + Tuple[ + str, + List[str], + List[str], + List[Union[float, Tuple[float, float]]], + List[Union[float, Tuple[float, float]]], + ] + ], enable_log: bool = True, -) -> Tuple[float, float, float]: + with_end_time: bool = False, +) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]: """Write statistics based on predicted results and reference transcripts as well as their timestamps. @@ -681,6 +690,8 @@ def write_error_stats_with_timestamps( enable_log: If True, also print detailed WER to the console. Otherwise, it is written only to the given file. + with_end_time: + Whether to use end timestamps. Returns: Return total word error rate and mean delay. 
@@ -726,7 +737,15 @@ def write_error_stats_with_timestamps( words[ref_word][0] += 1 num_corr += 1 if has_time: - all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) + if with_end_time: + all_delay.append( + ( + time_hyp[p_hyp][0] - time_ref[p_ref][0], + time_hyp[p_hyp][1] - time_ref[p_ref][1], + ) + ) + else: + all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) p_hyp += 1 p_ref += 1 if has_time: @@ -738,16 +757,39 @@ def write_error_stats_with_timestamps( ins_errs = sum(ins.values()) del_errs = sum(dels.values()) tot_errs = sub_errs + ins_errs + del_errs - tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len)) - mean_delay = "inf" - var_delay = "inf" + if with_end_time: + mean_delay = (float("inf"), float("inf")) + var_delay = (float("inf"), float("inf")) + else: + mean_delay = float("inf") + var_delay = float("inf") num_delay = len(all_delay) if num_delay > 0: - mean_delay = sum(all_delay) / num_delay - var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay - mean_delay = "%.3f" % mean_delay - var_delay = "%.3f" % var_delay + if with_end_time: + all_delay_start = [i[0] for i in all_delay] + mean_delay_start = sum(all_delay_start) / num_delay + var_delay_start = ( + sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay + ) + + all_delay_end = [i[1] for i in all_delay] + mean_delay_end = sum(all_delay_end) / num_delay + var_delay_end = ( + sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay + ) + + mean_delay = ( + float("%.3f" % mean_delay_start), + float("%.3f" % mean_delay_end), + ) + var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end)) + else: + mean_delay = sum(all_delay) / num_delay + var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay + mean_delay = float("%.3f" % mean_delay) + var_delay = float("%.3f" % var_delay) if enable_log: logging.info( @@ -756,7 +798,8 @@ def write_error_stats_with_timestamps( 
f"{del_errs} del, {sub_errs} sub ]" ) logging.info( - f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa + f"[{test_set_name}] %symbol-delay mean (s): " + f"{mean_delay}, variance: {var_delay} " # noqa f"computed on {num_delay} correct words" ) @@ -839,7 +882,8 @@ def write_error_stats_with_timestamps( hyp_count = corr + hyp_sub + ins print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - return float(tot_err_rate), float(mean_delay), float(var_delay) + + return tot_err_rate, mean_delay, var_delay class MetricsTracker(collections.defaultdict): @@ -1661,7 +1705,7 @@ def parse_fsa_timestamps_and_texts( frame_shift_ms: float = 10, ) -> Tuple[List[Tuple[float, float]], List[List[str]]]: """Parse timestamps (in seconds) and texts for given decoded fsa paths. - Currently it supports two case: + Currently it supports two cases: (1) ctc-decoding, the attribtutes `labels` and `aux_labels` are both BPE tokens. In this case, sp should be provided. (2) HLG-based 1best, the attribtute `labels` is the prediction unit,