mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
remove stats
This commit is contained in:
parent
5becf6927d
commit
80677a55f8
@ -431,45 +431,3 @@ def write_error_stats(
|
|||||||
|
|
||||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||||
return float(tot_err_rate)
|
return float(tot_err_rate)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_statistics(dataset, save_filename):
|
|
||||||
speech_token_lengths = []
|
|
||||||
text_lengths = []
|
|
||||||
for item in tqdm(dataset):
|
|
||||||
if 'custom' not in item:
|
|
||||||
speech_token = item["code"]
|
|
||||||
text = item["text"]
|
|
||||||
else:
|
|
||||||
speech_token = item["custom"]["speech_token"]
|
|
||||||
text = item["supervisions"][0]["text"]
|
|
||||||
speech_token_lengths.append(len(speech_token))
|
|
||||||
text_lengths.append(len(text))
|
|
||||||
speech_token_length_array = np.array(speech_token_lengths)
|
|
||||||
text_length_array = np.array(text_lengths)
|
|
||||||
# 计算并存储统计指标
|
|
||||||
def get_length_stats(lengths_array):
|
|
||||||
length_stats = []
|
|
||||||
length_stats.append(["count", f"{len(lengths_array)}"]) # 总数
|
|
||||||
length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
|
|
||||||
length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
|
|
||||||
length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
|
|
||||||
length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
|
|
||||||
length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"]) # median 和 50% percentile 是一样的
|
|
||||||
length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
|
|
||||||
length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
|
|
||||||
length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
|
|
||||||
length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
|
|
||||||
length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
|
|
||||||
return length_stats
|
|
||||||
speech_length_stats = get_length_stats(speech_token_length_array)
|
|
||||||
text_length_stats = get_length_stats(text_length_array)
|
|
||||||
with open(save_filename, "w") as f:
|
|
||||||
print("speech_tokens 长度统计指标:", file=f)
|
|
||||||
for stat_name, stat_value in speech_length_stats:
|
|
||||||
print(f"{stat_name:<15}: {stat_value}", file=f)
|
|
||||||
print("\ntext 长度统计指标:", file=f)
|
|
||||||
for stat_name, stat_value in text_length_stats:
|
|
||||||
print(f"{stat_name:<15}: {stat_value}", file=f)
|
|
||||||
|
|
||||||
return speech_token_lengths, text_lengths
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user