add voicebench eval

root 2025-05-13 05:37:11 +00:00
parent 89781b9bb1
commit cbf3af31fd
6 changed files with 558 additions and 1 deletion


@@ -211,3 +211,49 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
    --token2wav-path /workspace/CosyVoice2-0.5B \
    --use-lora True
fi
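
# Stages 12-14 below add a VoiceBench evaluation path: stage 12 decodes a single
# subset offline with torchrun, while stages 13 and 14 serve the same checkpoint
# behind a FastAPI /decode endpoint and replay every subset through it with a
# lightweight client.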
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
  log "stage 12: Decoding EN voicebench"
  exp_dir=./qwen_omni/exp_speech2speech_en_continue
  torchrun --nproc_per_node=2 \
    ./qwen_omni/decode_dist.py \
    --output-dir $exp_dir/log_voicebench \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --use-flash-attn True \
    --enable-speech-output True \
    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
    --use-lora True --subset-name openbookqa --split-name test
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
  log "stage 13: Server"
  exp_dir=./qwen_omni/exp_speech2speech_en_continue
  python3 ./qwen_omni/server.py \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
    --use-flash-attn True \
    --enable-speech-output True \
    --use-lora True
fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
  log "stage 14: Client"
  # datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval sd-qa)
  for dataset in "${datasets[@]}"; do
    # sd-qa should use the usa split; every other subset uses test
    if [ "$dataset" = "sd-qa" ]; then
      split_name="usa"
    else
      split_name="test"
    fi
    echo "$dataset $split_name"
    python3 ./qwen_omni/client.py \
      --subset-name "$dataset" --split-name "$split_name" \
      --output-dir test_result
  done
fi
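
# Stage 14 writes one JSONL file per subset, named <subset>-<split>.jsonl, under
# --output-dir (test_result above); each line keeps the original VoiceBench fields
# and adds the model's "response".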


@@ -0,0 +1,142 @@
# client.py
import argparse
import json
import os

import requests
from datasets import load_dataset
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser(description="Speech-to-Text Client")
    parser.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000",
        help="URL of the FastAPI server",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="hlt-lab/voicebench",
        help="Hugging Face dataset name",
    )
    parser.add_argument(
        "--subset-name",
        type=str,
        default="commoneval",  # Adjust as needed
        help="Dataset subset name",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="test",  # Adjust as needed
        help="Dataset split name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="Directory to save results"
    )
    args = parser.parse_args()
    return args
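
# Each dataset item is sent to the server as {"audio": [...], "sampling_rate": 16000};
# the JSONL line written per item keeps every non-audio field of the original
# VoiceBench example and adds the decoded "response" (or an "error" field on failure).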
def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)
    output_filename = os.path.join(
        args.output_dir,
        f"{args.subset_name}-{args.split_name}.jsonl",
    )
    server_decode_url = f"{args.server_url}/decode"

    print("Loading dataset...")
    dataset = load_dataset(
        args.dataset_name,
        args.subset_name,
        split=args.split_name,
        trust_remote_code=True,
    )
    print(f"Dataset loaded with {len(dataset)} samples.")
    print(f"Sending requests to {server_decode_url}...")
    print(f"Saving results to {output_filename}")

    with open(output_filename, "w", encoding="utf-8") as outfile:
        # Iterate directly over the dataset
        progress_bar = tqdm(dataset, desc="Processing", unit="samples")
        for item in progress_bar:
            audio_info = item.get("audio")
            assert (
                audio_info["sampling_rate"] == 16000
            ), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
            # Prepare data for JSON serialization and server request
            audio_array = audio_info["array"].tolist()  # Convert numpy array to list

            result_dict = {}
            for key in item.keys():
                if key != "audio":
                    # Ensure other fields are JSON serializable
                    try:
                        # Attempt to serialize to catch issues early (optional)
                        json.dumps(item[key])
                        result_dict[key] = item[key]
                    except (TypeError, OverflowError):
                        print(
                            f"Warning: Converting non-serializable key '{key}' to string."
                        )
                        result_dict[key] = str(
                            item[key]
                        )  # Convert problematic types to string

            payload = {
                "audio": audio_array,
                "sampling_rate": 16000,
            }

            try:
                response = requests.post(server_decode_url, json=payload, timeout=60)
                response.raise_for_status()
                server_response = response.json()
                decoded_text = server_response.get("text", "")
                # Add the response to the result dictionary
                result_dict["response"] = decoded_text
                print(result_dict)
                # Write result to JSONL file
                json.dump(result_dict, outfile, ensure_ascii=False)
                outfile.write("\n")
            except requests.exceptions.RequestException as e:
                print(f"\nError sending request for an item: {e}")
                error_entry = result_dict  # Use the data prepared so far
                error_entry["error"] = str(e)
                error_entry["response"] = ""
                json.dump(error_entry, outfile, ensure_ascii=False)
                outfile.write("\n")
            except json.JSONDecodeError:
                print("\nError decoding server response for an item.")
                error_entry = result_dict
                error_entry["error"] = "Invalid JSON response from server"
                error_entry["response"] = ""
                json.dump(error_entry, outfile, ensure_ascii=False)
                outfile.write("\n")
            except Exception as e:
                print(f"\nUnexpected error processing an item: {e}")
                error_entry = result_dict
                error_entry["error"] = f"Unexpected error: {str(e)}"
                error_entry["response"] = ""
                json.dump(error_entry, outfile, ensure_ascii=False)
                outfile.write("\n")
            # Progress bar updates automatically by iterating over tqdm(dataset)
    # No need to close progress_bar explicitly when iterating directly
    print("Processing finished.")


if __name__ == "__main__":
    main()
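
# Example invocation (as in stage 14 of the run script):
#   python3 ./qwen_omni/client.py \
#     --subset-name alpacaeval_full --split-name test \
#     --output-dir test_result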


@@ -461,6 +461,15 @@ class AsrDataModule:
        )
        return {"test": VoiceAssistant_cuts}
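
    # test_cuts_voicebench mirrors the test_cuts_* helpers above: it lazily loads a
    # manifest and returns {"test": CutSet} for the decoding script to consume.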
    def test_cuts_voicebench(
        self,
    ) -> CutSet:
        logging.info("About to get test cuts")
        VoiceAssistant_cuts = load_manifest_lazy(
            self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
        )
        return {"test": VoiceAssistant_cuts}

    # def train_cuts_en_vocalnet(self) -> CutSet:
    #     logging.info("About to get train cuts")
    #     VoiceAssistant_cuts = load_manifest_lazy(


@@ -0,0 +1,256 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#               2025                           (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
Decode a VoiceBench subset with 2 GPUs (matches stage 12 of the run script):

exp_dir=./qwen_omni/exp_speech2speech_en_continue
torchrun --nproc_per_node=2 \
    ./qwen_omni/decode_dist.py \
    --output-dir $exp_dir/log_voicebench \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --use-flash-attn True \
    --enable-speech-output True \
    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
    --use-lora True --subset-name openbookqa --split-name test
"""
import argparse
import json
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
import whisper
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")

try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass
def get_args():
    parser = argparse.ArgumentParser(description="extract speech code")
    parser.add_argument(
        "--split-name",
        type=str,
        default="test",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--subset-name",
        type=str,
        default="commoneval",
        help="subset name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=2, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
    # parser.add_argument(
    #     "--top-k",
    #     type=int,
    #     default=50,
    #     help="top k for sampling",
    # )
    # parser.add_argument(
    #     "--top-p",
    #     type=float,
    #     default=0.95,
    #     help="top p for sampling",
    # )
    # parser.add_argument(
    #     "--temperature",
    #     type=float,
    #     default=0.8,
    #     help="temperature for sampling",
    # )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args
def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
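
# WORLD_SIZE, LOCAL_RANK and RANK are injected by torchrun, so the script can be
# launched exactly as in stage 12 (torchrun --nproc_per_node=2 ...).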
def preprocess(
    messages,
    tokenizer,
):
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="longest",
                truncation=False,
            )
        )
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]
    input_ids = torch.tensor(texts, dtype=torch.int)
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    return input_ids, attention_mask
def custom_collate(batch):
    assert len(batch) == 1
    audio = batch[0]["audio"]
    assert audio["sampling_rate"] == 16000
    result = {"audio": audio["array"]}
    for keys in batch[0].keys():
        if keys != "audio":
            result[keys] = batch[0][keys]
    return result
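
# custom_collate assumes a per-device batch size of 1: raw waveforms of different
# lengths are not padded here, so each rank decodes one VoiceBench example at a time.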
def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    dataset = load_dataset(
        "hlt-lab/voicebench",
        args.subset_name,
        split=args.split_name,
        trust_remote_code=True,
    )

    model, tokenizer = get_model(args)
    # tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=custom_collate,
    )

    total_steps = len(dataset)
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    message = [
        {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
        {"role": "assistant", "content": ""},
    ]
    input_ids, attention_mask = preprocess([message], tokenizer)

    results_jsonl_file = open(
        os.path.join(
            args.output_dir,
            f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
        ),
        "w",
    )

    for batch in dataloader:
        audio = batch["audio"]
        audio = torch.from_numpy(audio).to(device).to(torch.float32)
        fbank = whisper.log_mel_spectrogram(audio, device=device)
        fbank = fbank.unsqueeze(0)
        generated_ids = model.decode(
            fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
        )
        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        result_dict = {}
        for key in batch.keys():
            if key != "audio":
                result_dict[key] = batch[key]
        result_dict["response"] = hyps[0]
        json.dump(result_dict, results_jsonl_file)
        results_jsonl_file.write("\n")

        if rank == 0:
            progress_bar.update(world_size * args.batch_size)

    if rank == 0:
        progress_bar.close()
    # Make sure this rank's shard is flushed before tearing down the process group.
    results_jsonl_file.close()
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
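
# Each rank writes its own shard, results-{subset}-{split}-{rank}-audio.jsonl.
# A minimal way to get one file per subset (hypothetical post-processing step,
# not part of this commit) is simply to concatenate the shards, e.g.:
#   cat results-openbookqa-test-*-audio.jsonl > results-openbookqa-test.jsonl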


@@ -0,0 +1,103 @@
# server.py
import argparse
import os
from typing import List

import torch
import uvicorn
import whisper
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model


def get_args():
    parser = argparse.ArgumentParser(description="extract speech code")
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args
class SpeechRequest(BaseModel):
    audio: List[float]  # Expecting audio as a list of floats (raw waveform)
    sampling_rate: int = 16000


class TextResponse(BaseModel):
    text: str
def preprocess_prompt(tokenizer):
    """Preprocesses the prompt template."""
    # `message` and `TEMPLATE` are module-level globals defined below, before this
    # function is first called.
    texts = [
        tokenizer.apply_chat_template(
            message,  # Using the hardcoded message
            tokenize=True,
            add_generation_prompt=False,  # Important for generation
            chat_template=TEMPLATE,
            padding=False,  # No padding needed for single prompt
            truncation=False,
        )
    ]
    input_ids = torch.tensor(texts, dtype=torch.long)
    attention_mask = torch.ones_like(
        input_ids, dtype=torch.bool
    )  # Mask is all True for the prompt
    return input_ids, attention_mask
args = get_args()
model, tokenizer = get_model(args)
app = FastAPI()
device = torch.device("cuda")

message = [
    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
    {"role": "assistant", "content": ""},
]
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
prompt_input_ids = prompt_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)
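
# The prompt only contains the speech placeholder token, so its ids and attention
# mask are computed once at startup and reused for every /decode request.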
@app.post("/decode", response_model=TextResponse)
async def decode_speech(request: SpeechRequest):
    """
    Receives audio waveform, processes it, and returns the decoded text.
    """
    if request.sampling_rate != 16000:
        raise HTTPException(
            status_code=400, detail="Only 16kHz sampling rate is supported."
        )
    try:
        audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
        fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
        fbank = fbank.unsqueeze(0)
        with torch.no_grad():
            generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)
        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        response_text = hyps[0] if hyps else ""
        return TextResponse(text=response_text)
    except Exception as e:
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")


if __name__ == "__main__":
    print("Starting server...")
    uvicorn.run(app, host="0.0.0.0", port=8000)
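
# Quick smoke test (hypothetical; real audio comes from client.py / the VoiceBench
# dataset). Assuming `samples` is a 16 kHz mono waveform as a list of floats:
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/decode",
#       json={"audio": samples, "sampling_rate": 16000},
#       timeout=60,
#   )
#   print(resp.json()["text"])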


@@ -74,7 +74,8 @@ def get_model(params, device="cuda"):
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )
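    # 6561 (= 3^8) is assumed to match the FSQ speech-token codebook of the
    # CosyVoice2 tokenizer (the token2wav model used above is CosyVoice2-0.5B);
    # the previous value, 4096 + 4, is kept commented out below, and the extra 4
    # slots are reserved for special tokens.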
    codec_vocab_size = 4096 + 4
    # codec_vocab_size = 4096 + 4
    codec_vocab_size = 6561 + 4
    config = Qwen2Config(
        vocab_size=codec_vocab_size,
        hidden_size=1024,