calculate symbol delay for (start, end) timestamps

This commit is contained in:
yaozengwei 2023-02-07 21:16:56 +08:00
parent 3e1d14b9f8
commit cca3113817
2 changed files with 70 additions and 19 deletions

View File

@ -633,7 +633,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
f, f"{test_set_name}-{key}", results, enable_log=True
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
with_end_time=True,
)
test_set_wers[key] = wer
test_set_delays[key] = (mean_delay, var_delay)
@ -649,16 +653,17 @@ def save_results(
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
# sort according to the mean start symbol delay
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
delays_info = (
params.res_dir
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(delays_info, "w") as f:
print("settings\tsymbol-delay", file=f)
print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
for key, val in test_set_delays:
print(
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
"{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
file=f,
)
@ -669,10 +674,12 @@ def save_results(
note = ""
logging.info(s)
s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
test_set_name
)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_delays:
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
note = ""
logging.info(s)

View File

@ -646,9 +646,18 @@ def write_error_stats(
def write_error_stats_with_timestamps(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
results: List[
Tuple[
str,
List[str],
List[str],
List[Union[float, Tuple[float, float]]],
List[Union[float, Tuple[float, float]]],
]
],
enable_log: bool = True,
) -> Tuple[float, float, float]:
with_end_time: bool = False,
) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]:
"""Write statistics based on predicted results and reference transcripts
as well as their timestamps.
@ -681,6 +690,8 @@ def write_error_stats_with_timestamps(
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
with_end_time:
Whether to use end timestamps, i.e. each timestamp is a (start, end) pair.
Returns:
Return total word error rate, mean delay, and variance of the delay.
@ -726,7 +737,15 @@ def write_error_stats_with_timestamps(
words[ref_word][0] += 1
num_corr += 1
if has_time:
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
if with_end_time:
all_delay.append(
(
time_hyp[p_hyp][0] - time_ref[p_ref][0],
time_hyp[p_hyp][1] - time_ref[p_ref][1],
)
)
else:
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
p_hyp += 1
p_ref += 1
if has_time:
@ -738,16 +757,39 @@ def write_error_stats_with_timestamps(
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len))
mean_delay = "inf"
var_delay = "inf"
if with_end_time:
mean_delay = (float("inf"), float("inf"))
var_delay = (float("inf"), float("inf"))
else:
mean_delay = float("inf")
var_delay = float("inf")
num_delay = len(all_delay)
if num_delay > 0:
mean_delay = sum(all_delay) / num_delay
var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
mean_delay = "%.3f" % mean_delay
var_delay = "%.3f" % var_delay
if with_end_time:
all_delay_start = [i[0] for i in all_delay]
mean_delay_start = sum(all_delay_start) / num_delay
var_delay_start = (
sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay
)
all_delay_end = [i[1] for i in all_delay]
mean_delay_end = sum(all_delay_end) / num_delay
var_delay_end = (
sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay
)
mean_delay = (
float("%.3f" % mean_delay_start),
float("%.3f" % mean_delay_end),
)
var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end))
else:
mean_delay = sum(all_delay) / num_delay
var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
mean_delay = float("%.3f" % mean_delay)
var_delay = float("%.3f" % var_delay)
if enable_log:
logging.info(
@ -756,7 +798,8 @@ def write_error_stats_with_timestamps(
f"{del_errs} del, {sub_errs} sub ]"
)
logging.info(
f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa
f"[{test_set_name}] %symbol-delay mean (s): "
f"{mean_delay}, variance: {var_delay} " # noqa
f"computed on {num_delay} correct words"
)
@ -839,7 +882,8 @@ def write_error_stats_with_timestamps(
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate), float(mean_delay), float(var_delay)
return tot_err_rate, mean_delay, var_delay
class MetricsTracker(collections.defaultdict):
@ -1661,7 +1705,7 @@ def parse_fsa_timestamps_and_texts(
frame_shift_ms: float = 10,
) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
"""Parse timestamps (in seconds) and texts for given decoded fsa paths.
Currently it supports two case:
Currently it supports two cases:
(1) ctc-decoding, the attributes `labels` and `aux_labels`
are both BPE tokens. In this case, sp should be provided.
(2) HLG-based 1best, the attribute `labels` is the prediction unit,