remove stats

This commit is contained in:
root 2025-06-03 00:48:39 -07:00
parent 5becf6927d
commit 80677a55f8

View File

@ -430,46 +430,4 @@ def write_error_stats(
hyp_count = corr + hyp_sub + ins hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate) return float(tot_err_rate)
def get_dataset_statistics(dataset, save_filename):
    """Collect speech-token and text length statistics and save a report.

    Every item of ``dataset`` is read in one of two layouts:

    * without a ``"custom"`` key: the tokens are ``item["code"]`` and the
      text is ``item["text"]``;
    * with a ``"custom"`` key: the tokens are
      ``item["custom"]["speech_token"]`` and the text is
      ``item["supervisions"][0]["text"]``.

    Args:
        dataset: Iterable of items in one of the layouts above.
        save_filename: Path of the text report to write; an existing file
            is overwritten. The report is always encoded as UTF-8.

    Returns:
        A tuple ``(speech_token_lengths, text_lengths)`` of two lists that
        hold ``len()`` of the tokens and of the text for every item, in
        iteration order.
    """
    speech_token_lengths = []
    text_lengths = []
    for item in tqdm(dataset):
        if "custom" in item:
            speech_token = item["custom"]["speech_token"]
            text = item["supervisions"][0]["text"]
        else:
            speech_token = item["code"]
            text = item["text"]
        speech_token_lengths.append(len(speech_token))
        text_lengths.append(len(text))

    # Compute and store the summary statistics.
    def get_length_stats(lengths_array):
        """Return ``[name, formatted value]`` rows describing the lengths."""
        length_stats = [["count", f"{len(lengths_array)}"]]  # total count
        if len(lengths_array) == 0:
            # np.min / np.max raise ValueError on a zero-size array, so an
            # empty dataset is reported with its count only.
            return length_stats

        length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
        length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
        length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
        length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
        # The median is the same value as the 50th percentile.
        length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"])
        for label, q in (("75%", 75), ("99%", 99), ("99.5%", 99.5), ("99.9%", 99.9)):
            length_stats.append([label, f"{np.percentile(lengths_array, q):.1f}"])
        length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
        return length_stats

    sections = (
        ("speech_tokens 长度统计指标:", get_length_stats(np.array(speech_token_lengths))),
        ("\ntext 长度统计指标:", get_length_stats(np.array(text_lengths))),
    )

    # The headings contain non-ASCII characters; an explicit encoding keeps
    # the write from depending on (or failing under) the platform locale.
    with open(save_filename, "w", encoding="utf-8") as f:
        for title, stats in sections:
            print(title, file=f)
            for stat_name, stat_value in stats:
                print(f"{stat_name:<15}: {stat_value}", file=f)

    return speech_token_lengths, text_lengths