From cbf3af31fd2144bea91c66fd303f6f213a6740e7 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 13 May 2025 05:37:11 +0000
Subject: [PATCH] add voicebench eval

---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh       |  46 ++++
 .../SPEECH2SPEECH/qwen_omni/client.py         | 142 ++++++++++
 .../SPEECH2SPEECH/qwen_omni/data_module.py    |   9 +
 .../SPEECH2SPEECH/qwen_omni/decode_dist.py    | 256 ++++++++++++++++++
 .../SPEECH2SPEECH/qwen_omni/server.py         | 103 +++++++
 .../SPEECH2SPEECH/qwen_omni/web_demo.py       |   3 +-
 6 files changed, 558 insertions(+), 1 deletion(-)
 create mode 100644 egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
 create mode 100644 egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
 create mode 100644 egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index 6d8f54135..25cd79810 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -211,3 +211,49 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
     --token2wav-path /workspace/CosyVoice2-0.5B \
     --use-lora True
 fi
+
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "stage 12: Decoding EN voicebench"
+  exp_dir=./qwen_omni/exp_speech2speech_en_continue
+  torchrun --nproc_per_node=2 \
+    ./qwen_omni/decode_dist.py \
+    --output-dir $exp_dir/log_voicebench \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --use-flash-attn True \
+    --enable-speech-output True \
+    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
+    --use-lora True --subset-name openbookqa --split-name test
+fi
+
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "stage 13: Server"
+  exp_dir=./qwen_omni/exp_speech2speech_en_continue
+  python3 ./qwen_omni/server.py \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
+    --use-flash-attn True \
+    --enable-speech-output True \
+    --use-lora True
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "stage 14: Client"
+  # datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
+  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval sd-qa)
+  for dataset in ${datasets[@]}; do
+    # sd-qa should use usa split
+    if [ $dataset == "sd-qa" ]; then
+      split_name="usa"
+    else
+      split_name="test"
+    fi
+    echo $dataset $split_name
+    python3 ./qwen_omni/client.py \
+      --subset-name $dataset --split-name $split_name \
+      --output-dir test_result
+  done
+fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
new file mode 100644
index 000000000..822d7d709
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
@@ -0,0 +1,142 @@
+# client.py
+import argparse
+import json
+import os
+
+import requests
+from datasets import load_dataset
+from tqdm import tqdm
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Speech-to-Text Client")
+    parser.add_argument(
+        "--server-url",
+        type=str,
+        default="http://localhost:8000",
+        help="URL of the FastAPI server",
+    )
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="hlt-lab/voicebench",
+        help="Hugging Face dataset name",
+    )
+    parser.add_argument(
+        "--subset-name",
+        type=str,
+        default="commoneval",  # Adjust as needed
+        help="Dataset subset name",
+    )
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="test",  # Adjust as needed
+        help="Dataset split name",
+    )
+    parser.add_argument(
+        "--output-dir", required=True, type=str, help="Directory to save results"
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+    os.makedirs(args.output_dir, exist_ok=True)
+    output_filename = os.path.join(
+        args.output_dir,
+        f"{args.subset_name}-{args.split_name}.jsonl",
+    )
+    server_decode_url = f"{args.server_url}/decode"
+
+    print("Loading dataset...")
+
+    dataset = load_dataset(
+        args.dataset_name,
+        args.subset_name,
+        split=args.split_name,
+        trust_remote_code=True,
+    )
+
+    print(f"Dataset loaded with {len(dataset)} samples.")
+    print(f"Sending requests to {server_decode_url}...")
+    print(f"Saving results to {output_filename}")
+
+    with open(output_filename, "w", encoding="utf-8") as outfile:
+        # Iterate directly over the dataset
+        progress_bar = tqdm(dataset, desc="Processing", unit="samples")
+        for item in progress_bar:
+
+            audio_info = item.get("audio")
+            assert (
+                audio_info["sampling_rate"] == 16000
+            ), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
+
+            # Prepare data for JSON serialization and server request
+            audio_array = audio_info["array"].tolist()  # Convert numpy array to list
+            result_dict = {}
+            for key in item.keys():
+                if key != "audio":
+                    # Ensure other fields are JSON serializable
+                    try:
+                        # Attempt to serialize to catch issues early (optional)
+                        json.dumps(item[key])
+                        result_dict[key] = item[key]
+                    except (TypeError, OverflowError):
+                        print(
+                            f"Warning: Converting non-serializable key '{key}' to string."
+                        )
+                        result_dict[key] = str(
+                            item[key]
+                        )  # Convert problematic types to string
+
+            payload = {
+                "audio": audio_array,
+                "sampling_rate": 16000,
+            }
+
+            try:
+                response = requests.post(server_decode_url, json=payload, timeout=60)
+                response.raise_for_status()
+                server_response = response.json()
+                decoded_text = server_response.get("text", "")
+
+                # Add the response to the result dictionary
+                result_dict["response"] = decoded_text
+                print(result_dict)
+                # Write result to JSONL file
+                json.dump(result_dict, outfile, ensure_ascii=False)
+                outfile.write("\n")
+
+            except requests.exceptions.RequestException as e:
+                print(f"\nError sending request for an item: {e}")
+                error_entry = result_dict  # Use the data prepared so far
+                error_entry["error"] = str(e)
+                error_entry["response"] = ""
+                json.dump(error_entry, outfile, ensure_ascii=False)
+                outfile.write("\n")
+            except json.JSONDecodeError:
+                print("\nError decoding server response for an item.")
+                error_entry = result_dict
+                error_entry["error"] = "Invalid JSON response from server"
+                error_entry["response"] = ""
+                json.dump(error_entry, outfile, ensure_ascii=False)
+                outfile.write("\n")
+            except Exception as e:
+                print(f"\nUnexpected error processing an item: {e}")
+                error_entry = result_dict
+                error_entry["error"] = f"Unexpected error: {str(e)}"
+                error_entry["response"] = ""
+                json.dump(error_entry, outfile, ensure_ascii=False)
+                outfile.write("\n")
+
+            # Progress bar updates automatically by iterating over tqdm(dataset)
+
+    # No need to close progress_bar explicitly when iterating directly
+
+    print("Processing finished.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
index b02c9f4bf..bc75bccd6 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
@@ -461,6 +461,15 @@ class AsrDataModule:
         )
         return {"test": VoiceAssistant_cuts}
 
+    def test_cuts_voicebench(
+        self,
+    ) -> CutSet:
+        logging.info("About to get test cuts")
+        VoiceAssistant_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
+        )
+        return {"test": VoiceAssistant_cuts}
+
     # def train_cuts_en_vocalnet(self) -> CutSet:
     #     logging.info("About to get train cuts")
     #     VoiceAssistant_cuts = load_manifest_lazy(
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
new file mode 100644
index 000000000..dd69fce10
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
@@ -0,0 +1,256 @@
+# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
+#               2025                (authors: Yuekai Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
+""" Example Usage
+Distributed decoding of the VoiceBench dataset (hlt-lab/voicebench).
+Each rank writes results-<subset>-<split>-<rank>-audio.jsonl to --output-dir.
+
+exp_dir=./qwen_omni/exp_speech2speech_en_continue
+torchrun --nproc_per_node=2 \
+    ./qwen_omni/decode_dist.py \
+    --output-dir $exp_dir/log_voicebench \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --use-flash-attn True \
+    --enable-speech-output True \
+    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
+    --use-lora True \
+    --subset-name openbookqa \
+    --split-name test
+
+The sd-qa subset uses the 'usa' split; the other subsets use 'test'.
+See prepare.sh stages 12-14 for the full evaluation pipeline.
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import whisper
+from datasets import load_dataset
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from tqdm import tqdm
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoTokenizer
+from web_demo import get_model
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
+# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Decode VoiceBench with multiple GPUs")
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="test",
+        help="huggingface dataset split name",
+    )
+    parser.add_argument(
+        "--subset-name",
+        type=str,
+        default="commoneval",
+        help="subset name",
+    )
+    parser.add_argument(
+        "--output-dir", required=True, type=str, help="dir to save result"
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=1,
+        help="batch size (per-device) for inference",
+    )
+    parser.add_argument(
+        "--num-workers", type=int, default=2, help="workers for dataloader"
+    )
+    parser.add_argument(
+        "--prefetch", type=int, default=2, help="prefetch for dataloader"
+    )
+    parser.add_argument(
+        "--checkpoint-path",
+        type=str,
+        default=None,
+        help="Checkpoint name or path, default to %(default)r",
+    )
+    # parser.add_argument(
+    #     "--top-k",
+    #     type=int,
+    #     default=50,
+    #     help="top k for sampling",
+    # )
+    # parser.add_argument(
+    #     "--top-p",
+    #     type=float,
+    #     default=0.95,
+    #     help="top p for sampling",
+    # )
+    # parser.add_argument(
+    #     "--temperature",
+    #     type=float,
+    #     default=0.8,
+    #     help="temperature for sampling",
+    # )
+    add_model_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+def init_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    rank = int(os.environ.get("RANK", 0))
+    print(
+        "Inference on multiple gpus, this gpu {}".format(local_rank)
+        + ", rank {}, world_size {}".format(rank, world_size)
+    )
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group("nccl")
+    return world_size, local_rank, rank
+
+
+def preprocess(
+    messages,
+    tokenizer,
+):
+    """Preprocesses the data for supervised fine-tuning."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for i, msg in enumerate(messages):
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                add_generation_prompt=False,
+                chat_template=TEMPLATE,
+                padding="longest",
+                truncation=False,
+            )
+        )
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+
+    input_ids = torch.tensor(texts, dtype=torch.int)
+
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+    return input_ids, attention_mask
+
+
+def custom_collate(batch):
+    assert len(batch) == 1
+    audio = batch[0]["audio"]
+    assert audio["sampling_rate"] == 16000
+    result = {"audio": audio["array"]}
+    for keys in batch[0].keys():
+        if keys != "audio":
+            result[keys] = batch[0][keys]
+    return result
+
+
+def main():
+    args = get_args()
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    assert torch.cuda.is_available()
+    world_size, local_rank, rank = init_distributed()
+    device = torch.device(f"cuda:{local_rank}")
+
+    dataset = load_dataset(
+        "hlt-lab/voicebench",
+        args.subset_name,
+        split=args.split_name,
+        trust_remote_code=True,
+    )
+
+    model, tokenizer = get_model(args)
+    # tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)
+    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        sampler=sampler,
+        shuffle=False,
+        num_workers=args.num_workers,
+        prefetch_factor=args.prefetch,
+        collate_fn=custom_collate,
+    )
+
+    total_steps = len(dataset)
+
+    if rank == 0:
+        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+    message = [
+        {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+        {"role": "assistant", "content": ""},
+    ]
+    input_ids, attention_mask = preprocess([message], tokenizer)
+    results_jsonl_file = open(
+        os.path.join(
+            args.output_dir,
+            f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
+        ),
+        "w",
+    )
+    for batch in dataloader:
+        audio = batch["audio"]
+        audio = torch.from_numpy(audio).to(device).to(torch.float32)
+        fbank = whisper.log_mel_spectrogram(audio, device=device)
+        fbank = fbank.unsqueeze(0)
+        generated_ids = model.decode(
+            fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        )
+        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+        result_dict = {}
+        for key in batch.keys():
+            if key != "audio":
+                result_dict[key] = batch[key]
+        result_dict["response"] = hyps[0]
+        json.dump(result_dict, results_jsonl_file)
+        results_jsonl_file.write("\n")
+
+        if rank == 0:
+            progress_bar.update(world_size * args.batch_size)
+
+    if rank == 0:
+        progress_bar.close()
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
new file mode 100644
index 000000000..2f06b923a
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
@@ -0,0 +1,103 @@
+# server.py
+import argparse
+import os
+from typing import List
+
+import torch
+import uvicorn
+import whisper
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoTokenizer
+from web_demo import get_model
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Speech-to-text decoding server")
+    parser.add_argument(
+        "--checkpoint-path",
+        type=str,
+        default=None,
+        help="Checkpoint name or path, default to %(default)r",
+    )
+    add_model_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+class SpeechRequest(BaseModel):
+    audio: List[float]  # Expecting audio as a list of floats (raw waveform)
+    sampling_rate: int = 16000
+
+
+class TextResponse(BaseModel):
+    text: str
+
+
+def preprocess_prompt(tokenizer):
+    """Preprocesses the prompt template."""
+    texts = [
+        tokenizer.apply_chat_template(
+            message,  # Using the hardcoded message
+            tokenize=True,
+            add_generation_prompt=False,  # Important for generation
+            chat_template=TEMPLATE,
+            padding=False,  # No padding needed for single prompt
+            truncation=False,
+        )
+    ]
+    input_ids = torch.tensor(texts, dtype=torch.long)
+    attention_mask = torch.ones_like(
+        input_ids, dtype=torch.bool
+    )  # Mask is all True for the prompt
+    return input_ids, attention_mask
+
+
+args = get_args()
+model, tokenizer = get_model(args)
+app = FastAPI()
+
+device = torch.device("cuda")
+message = [
+    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+    {"role": "assistant", "content": ""},
+]
+TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
+prompt_input_ids = prompt_input_ids.to(device)
+prompt_attention_mask = prompt_attention_mask.to(device)
+
+
+@app.post("/decode", response_model=TextResponse)
+async def decode_speech(request: SpeechRequest):
+    """
+    Receives audio waveform, processes it, and returns the decoded text.
+    """
+    if request.sampling_rate != 16000:
+        raise HTTPException(
+            status_code=400, detail="Only 16kHz sampling rate is supported."
+        )
+
+    try:
+        audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
+        fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
+        fbank = fbank.unsqueeze(0)
+
+        with torch.no_grad():
+            generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)
+
+        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+        response_text = hyps[0] if hyps else ""
+
+        return TextResponse(text=response_text)
+
+    except Exception as e:
+        print(f"Error during processing: {e}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
+
+
+if __name__ == "__main__":
+    print("Starting server...")
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
index e33d2437d..1ad05b0a6 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
@@ -74,7 +74,8 @@ def get_model(params, device="cuda"):
         speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
     )
 
-    codec_vocab_size = 4096 + 4
+    # codec_vocab_size = 4096 + 4
+    codec_vocab_size = 6561 + 4
     config = Qwen2Config(
         vocab_size=codec_vocab_size,
         hidden_size=1024,