diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
index 81f7c0d5c..fad7f272c 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
@@ -430,46 +430,4 @@ def write_error_stats(
         hyp_count = corr + hyp_sub + ins
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
 
-    return float(tot_err_rate)
-
-
-def get_dataset_statistics(dataset, save_filename):
-    speech_token_lengths = []
-    text_lengths = []
-    for item in tqdm(dataset):
-        if 'custom' not in item:
-            speech_token = item["code"]
-            text = item["text"]
-        else:
-            speech_token = item["custom"]["speech_token"]
-            text = item["supervisions"][0]["text"]
-        speech_token_lengths.append(len(speech_token))
-        text_lengths.append(len(text))
-    speech_token_length_array = np.array(speech_token_lengths)
-    text_length_array = np.array(text_lengths)
-    # Compute and collect the length statistics.
-    def get_length_stats(lengths_array):
-        length_stats = []
-        length_stats.append(["count", f"{len(lengths_array)}"])  # total number of items
-        length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
-        length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
-        length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
-        length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
-        length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"])  # the median equals the 50th percentile
-        length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
-        length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
-        length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
-        length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
-        length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
-        return length_stats
-    speech_length_stats = get_length_stats(speech_token_length_array)
-    text_length_stats = get_length_stats(text_length_array)
-    with open(save_filename, "w") as f:
-        print("speech_tokens length statistics:", file=f)
-        for stat_name, stat_value in speech_length_stats:
-            print(f"{stat_name:<15}: {stat_value}", file=f)
-        print("\ntext length statistics:", file=f)
-        for stat_name, stat_value in text_length_stats:
-            print(f"{stat_name:<15}: {stat_value}", file=f)
-
-    return speech_token_lengths, text_lengths
+    return float(tot_err_rate)
\ No newline at end of file