calculate symbol delay for (start, end) timestamps

This commit is contained in:
yaozengwei 2023-02-07 21:16:56 +08:00
parent 3e1d14b9f8
commit cca3113817
2 changed files with 70 additions and 19 deletions

View File

@ -633,7 +633,11 @@ def save_results(
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer, mean_delay, var_delay = write_error_stats_with_timestamps( wer, mean_delay, var_delay = write_error_stats_with_timestamps(
f, f"{test_set_name}-{key}", results, enable_log=True f,
f"{test_set_name}-{key}",
results,
enable_log=True,
with_end_time=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer
test_set_delays[key] = (mean_delay, var_delay) test_set_delays[key] = (mean_delay, var_delay)
@ -649,16 +653,17 @@ def save_results(
for key, val in test_set_wers: for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f) print("{}\t{}".format(key, val), file=f)
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) # sort according to the mean start symbol delay
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
delays_info = ( delays_info = (
params.res_dir params.res_dir
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(delays_info, "w") as f: with open(delays_info, "w") as f:
print("settings\tsymbol-delay", file=f) print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
for key, val in test_set_delays: for key, val in test_set_delays:
print( print(
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), "{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
file=f, file=f,
) )
@ -669,10 +674,12 @@ def save_results(
note = "" note = ""
logging.info(s) logging.info(s)
s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
test_set_name
)
note = "\tbest for {}".format(test_set_name) note = "\tbest for {}".format(test_set_name)
for key, val in test_set_delays: for key, val in test_set_delays:
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
note = "" note = ""
logging.info(s) logging.info(s)

View File

@ -646,9 +646,18 @@ def write_error_stats(
def write_error_stats_with_timestamps( def write_error_stats_with_timestamps(
f: TextIO, f: TextIO,
test_set_name: str, test_set_name: str,
results: List[Tuple[str, List[str], List[str], List[float], List[float]]], results: List[
Tuple[
str,
List[str],
List[str],
List[Union[float, Tuple[float, float]]],
List[Union[float, Tuple[float, float]]],
]
],
enable_log: bool = True, enable_log: bool = True,
) -> Tuple[float, float, float]: with_end_time: bool = False,
) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]:
"""Write statistics based on predicted results and reference transcripts """Write statistics based on predicted results and reference transcripts
as well as their timestamps. as well as their timestamps.
@ -681,6 +690,8 @@ def write_error_stats_with_timestamps(
enable_log: enable_log:
If True, also print detailed WER to the console. If True, also print detailed WER to the console.
Otherwise, it is written only to the given file. Otherwise, it is written only to the given file.
with_end_time:
Whether use end timestamps.
Returns: Returns:
Return total word error rate and mean delay. Return total word error rate and mean delay.
@ -726,6 +737,14 @@ def write_error_stats_with_timestamps(
words[ref_word][0] += 1 words[ref_word][0] += 1
num_corr += 1 num_corr += 1
if has_time: if has_time:
if with_end_time:
all_delay.append(
(
time_hyp[p_hyp][0] - time_ref[p_ref][0],
time_hyp[p_hyp][1] - time_ref[p_ref][1],
)
)
else:
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
p_hyp += 1 p_hyp += 1
p_ref += 1 p_ref += 1
@ -738,16 +757,39 @@ def write_error_stats_with_timestamps(
ins_errs = sum(ins.values()) ins_errs = sum(ins.values())
del_errs = sum(dels.values()) del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len))
mean_delay = "inf" if with_end_time:
var_delay = "inf" mean_delay = (float("inf"), float("inf"))
var_delay = (float("inf"), float("inf"))
else:
mean_delay = float("inf")
var_delay = float("inf")
num_delay = len(all_delay) num_delay = len(all_delay)
if num_delay > 0: if num_delay > 0:
if with_end_time:
all_delay_start = [i[0] for i in all_delay]
mean_delay_start = sum(all_delay_start) / num_delay
var_delay_start = (
sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay
)
all_delay_end = [i[1] for i in all_delay]
mean_delay_end = sum(all_delay_end) / num_delay
var_delay_end = (
sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay
)
mean_delay = (
float("%.3f" % mean_delay_start),
float("%.3f" % mean_delay_end),
)
var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end))
else:
mean_delay = sum(all_delay) / num_delay mean_delay = sum(all_delay) / num_delay
var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
mean_delay = "%.3f" % mean_delay mean_delay = float("%.3f" % mean_delay)
var_delay = "%.3f" % var_delay var_delay = float("%.3f" % var_delay)
if enable_log: if enable_log:
logging.info( logging.info(
@ -756,7 +798,8 @@ def write_error_stats_with_timestamps(
f"{del_errs} del, {sub_errs} sub ]" f"{del_errs} del, {sub_errs} sub ]"
) )
logging.info( logging.info(
f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa f"[{test_set_name}] %symbol-delay mean (s): "
f"{mean_delay}, variance: {var_delay} " # noqa
f"computed on {num_delay} correct words" f"computed on {num_delay} correct words"
) )
@ -839,7 +882,8 @@ def write_error_stats_with_timestamps(
hyp_count = corr + hyp_sub + ins hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate), float(mean_delay), float(var_delay)
return tot_err_rate, mean_delay, var_delay
class MetricsTracker(collections.defaultdict): class MetricsTracker(collections.defaultdict):
@ -1661,7 +1705,7 @@ def parse_fsa_timestamps_and_texts(
frame_shift_ms: float = 10, frame_shift_ms: float = 10,
) -> Tuple[List[Tuple[float, float]], List[List[str]]]: ) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
"""Parse timestamps (in seconds) and texts for given decoded fsa paths. """Parse timestamps (in seconds) and texts for given decoded fsa paths.
Currently it supports two case: Currently it supports two cases:
(1) ctc-decoding, the attribtutes `labels` and `aux_labels` (1) ctc-decoding, the attribtutes `labels` and `aux_labels`
are both BPE tokens. In this case, sp should be provided. are both BPE tokens. In this case, sp should be provided.
(2) HLG-based 1best, the attribtute `labels` is the prediction unit, (2) HLG-based 1best, the attribtute `labels` is the prediction unit,