From e65725810ce47a2c0b3933235d337ffb0f9f5b67 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 13 May 2025 09:13:12 +0000
Subject: [PATCH] fix mmsu

---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh          |  3 +--
 .../SPEECH2SPEECH/qwen_omni/client.py            | 26 ++++++++++++-------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index 25cd79810..c974ee88f 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -242,8 +242,7 @@ fi
 
 if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
   log "stage 14: Client"
-  # datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
-  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval sd-qa)
+  datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
   for dataset in ${datasets[@]}; do
     # sd-qa should use usa split
     if [ $dataset == "sd-qa" ]; then
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
index 822d7d709..05c363979 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
@@ -4,7 +4,7 @@ import json
 import os
 
 import requests
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from tqdm import tqdm
 
 
@@ -31,7 +31,7 @@ def get_args():
     parser.add_argument(
         "--split-name",
         type=str,
-        default="test",  # Adjust as needed
+        default=None,  # Adjust as needed
         help="Dataset split name",
     )
     parser.add_argument(
@@ -51,13 +51,21 @@ def main():
     server_decode_url = f"{args.server_url}/decode"
 
     print("Loading dataset...")
-
-    dataset = load_dataset(
-        args.dataset_name,
-        args.subset_name,
-        split=args.split_name,
-        trust_remote_code=True,
-    )
+    if args.subset_name != "mmsu":
+        dataset = load_dataset(
+            args.dataset_name,
+            args.subset_name,
+            split=args.split_name,
+            trust_remote_code=True,
+        )
+    else:
+        # load all splits and concatenate them
+        dataset = load_dataset(
+            args.dataset_name,
+            args.subset_name,
+            trust_remote_code=True,
+        )
+        dataset = concatenate_datasets([dataset[subset] for subset in dataset])
     print(f"Dataset loaded with {len(dataset)} samples.")
 
     print(f"Sending requests to {server_decode_url}...")
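
Note (outside the patch itself): the new mmsu branch works because load_dataset with no split argument returns a DatasetDict keyed by split name; iterating it yields those names, and concatenate_datasets flattens the per-split Datasets into one. A minimal standalone sketch of the same pattern, using a placeholder dataset name that is not taken from the patch:

    from datasets import concatenate_datasets, load_dataset

    # Placeholder dataset/subset names, for illustration only.
    # With no `split=` argument, load_dataset returns a DatasetDict keyed by split name.
    splits = load_dataset("org/placeholder-dataset", "mmsu", trust_remote_code=True)

    # Merge every split into a single Dataset, as the mmsu branch above does.
    merged = concatenate_datasets([splits[name] for name in splits])
    print(f"{len(merged)} samples across {len(splits)} splits")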