mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

commit cbf3af31fd (parent: 89781b9bb1)

    add voicebench eval
@@ -211,3 +211,49 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
    --token2wav-path /workspace/CosyVoice2-0.5B \
    --use-lora True
fi

if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
  log "stage 12: Decoding EN voicebench"
  exp_dir=./qwen_omni/exp_speech2speech_en_continue
  torchrun --nproc_per_node=2 \
    ./qwen_omni/decode_dist.py \
      --output-dir $exp_dir/log_voicebench \
      --speech-encoder-path-or-name models/large-v2.pt \
      --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
      --use-flash-attn True \
      --enable-speech-output True \
      --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
      --use-lora True --subset-name openbookqa --split-name test
fi

if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
  log "stage 13: Server"
  exp_dir=./qwen_omni/exp_speech2speech_en_continue
  python3 ./qwen_omni/server.py \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
    --use-flash-attn True \
    --enable-speech-output True \
    --use-lora True
fi

if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
  log "stage 14: Client"
  # datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval sd-qa)
  for dataset in ${datasets[@]}; do
    # sd-qa should use the usa split
    if [ $dataset == "sd-qa" ]; then
      split_name="usa"
    else
      split_name="test"
    fi
    echo $dataset $split_name
    python3 ./qwen_omni/client.py \
      --subset-name $dataset --split-name $split_name \
      --output-dir test_result
  done
fi
egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# client.py
import argparse
import json
import os

import requests
from datasets import load_dataset
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser(description="Speech-to-Text Client")
    parser.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000",
        help="URL of the FastAPI server",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="hlt-lab/voicebench",
        help="Hugging Face dataset name",
    )
    parser.add_argument(
        "--subset-name",
        type=str,
        default="commoneval",  # Adjust as needed
        help="Dataset subset name",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="test",  # Adjust as needed
        help="Dataset split name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="Directory to save results"
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)
    output_filename = os.path.join(
        args.output_dir,
        f"{args.subset_name}-{args.split_name}.jsonl",
    )
    server_decode_url = f"{args.server_url}/decode"

    print("Loading dataset...")

    dataset = load_dataset(
        args.dataset_name,
        args.subset_name,
        split=args.split_name,
        trust_remote_code=True,
    )

    print(f"Dataset loaded with {len(dataset)} samples.")
    print(f"Sending requests to {server_decode_url}...")
    print(f"Saving results to {output_filename}")

    with open(output_filename, "w", encoding="utf-8") as outfile:
        # Iterate directly over the dataset
        progress_bar = tqdm(dataset, desc="Processing", unit="samples")
        for item in progress_bar:

            audio_info = item.get("audio")
            assert (
                audio_info["sampling_rate"] == 16000
            ), f"Sampling rate is {audio_info['sampling_rate']}, not 16 kHz"

            # Prepare data for JSON serialization and server request
            audio_array = audio_info["array"].tolist()  # Convert numpy array to list
            result_dict = {}
            for key in item.keys():
                if key != "audio":
                    # Ensure other fields are JSON serializable
                    try:
                        # Attempt to serialize to catch issues early (optional)
                        json.dumps(item[key])
                        result_dict[key] = item[key]
                    except (TypeError, OverflowError):
                        print(
                            f"Warning: Converting non-serializable key '{key}' to string."
                        )
                        result_dict[key] = str(
                            item[key]
                        )  # Convert problematic types to string

            payload = {
                "audio": audio_array,
                "sampling_rate": 16000,
            }

            try:
                response = requests.post(server_decode_url, json=payload, timeout=60)
                response.raise_for_status()
                server_response = response.json()
                decoded_text = server_response.get("text", "")

                # Add the response to the result dictionary
                result_dict["response"] = decoded_text
                print(result_dict)
                # Write result to JSONL file
                json.dump(result_dict, outfile, ensure_ascii=False)
                outfile.write("\n")

            except requests.exceptions.RequestException as e:
                print(f"\nError sending request for an item: {e}")
                error_entry = result_dict  # Use the data prepared so far
                error_entry["error"] = str(e)
                error_entry["response"] = ""
                json.dump(error_entry, outfile, ensure_ascii=False)
                outfile.write("\n")
            except json.JSONDecodeError:
                print("\nError decoding server response for an item.")
                error_entry = result_dict
                error_entry["error"] = "Invalid JSON response from server"
                error_entry["response"] = ""
                json.dump(error_entry, outfile, ensure_ascii=False)
                outfile.write("\n")
            except Exception as e:
                print(f"\nUnexpected error processing an item: {e}")
                error_entry = result_dict
                error_entry["error"] = f"Unexpected error: {str(e)}"
                error_entry["response"] = ""
                json.dump(error_entry, outfile, ensure_ascii=False)
                outfile.write("\n")

        # Progress bar updates automatically by iterating over tqdm(dataset)
        # No need to close progress_bar explicitly when iterating directly

    print("Processing finished.")


if __name__ == "__main__":
    main()
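For quick inspection of the client output, here is a minimal sketch (not part of this commit) that reads back the `{subset}-{split}.jsonl` file written above. Only the `response` and `error` keys are guaranteed by client.py; any other keys (e.g. a `prompt` column) come from whatever fields the VoiceBench subset provides, so treat them as assumptions.

# inspect_client_output.py -- illustrative sketch only, not part of this commit
import json
import sys

# e.g. python3 inspect_client_output.py test_result/commoneval-test.jsonl
path = sys.argv[1]
with open(path, encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

n_errors = sum(1 for r in records if "error" in r)
print(f"{len(records)} records, {n_errors} with errors")
for r in records[:3]:
    # "response" is written by client.py; other keys are copied from the dataset item
    print({k: v for k, v in r.items() if k != "audio"})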
@@ -461,6 +461,15 @@ class AsrDataModule:
            )
        )
        return {"test": VoiceAssistant_cuts}

    def test_cuts_voicebench(
        self,
    ) -> CutSet:
        logging.info("About to get test cuts")
        VoiceAssistant_cuts = load_manifest_lazy(
            self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
        )
        return {"test": VoiceAssistant_cuts}

    # def train_cuts_en_vocalnet(self) -> CutSet:
    #     logging.info("About to get train cuts")
    #     VoiceAssistant_cuts = load_manifest_lazy(
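Although the new method is annotated `-> CutSet`, it actually returns a dict mapping a set name to a CutSet, mirroring the method above it. A minimal consumption sketch follows, assuming the recipe's data module lives in `data_module.py` and exposes the usual icefall-style `test_dataloaders` helper (both assumptions, not shown in this diff):

# Sketch only: iterate the returned mapping rather than treating it as one CutSet.
from data_module import AsrDataModule  # assumed module name

data_module = AsrDataModule(args)  # args from the recipe's argument parser
for name, cuts in data_module.test_cuts_voicebench().items():
    test_dl = data_module.test_dataloaders(cuts)  # assumed helper
    for batch in test_dl:
        pass  # run decoding on each batch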
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py (new file, 256 lines)
@@ -0,0 +1,256 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#               2025                (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
exp_dir=./qwen_omni/exp_speech2speech_en_continue
torchrun --nproc_per_node=2 \
    ./qwen_omni/decode_dist.py \
    --output-dir $exp_dir/log_voicebench \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --use-flash-attn True \
    --enable-speech-output True \
    --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
    --use-lora True --subset-name openbookqa --split-name test
"""

import argparse
import json
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
import whisper
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass


def get_args():
    parser = argparse.ArgumentParser(description="distributed voicebench decoding")
    parser.add_argument(
        "--split-name",
        type=str,
        default="test",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--subset-name",
        type=str,
        default="commoneval",
        help="subset name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=2, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
    # parser.add_argument(
    #     "--top-k",
    #     type=int,
    #     default=50,
    #     help="top k for sampling",
    # )
    # parser.add_argument(
    #     "--top-p",
    #     type=float,
    #     default=0.95,
    #     help="top p for sampling",
    # )
    # parser.add_argument(
    #     "--temperature",
    #     type=float,
    #     default=0.8,
    #     help="temperature for sampling",
    # )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args


def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


def preprocess(
    messages,
    tokenizer,
):
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="longest",
                truncation=False,
            )
        )
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]

    input_ids = torch.tensor(texts, dtype=torch.int)

    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    return input_ids, attention_mask


def custom_collate(batch):
    assert len(batch) == 1
    audio = batch[0]["audio"]
    assert audio["sampling_rate"] == 16000
    result = {"audio": audio["array"]}
    for keys in batch[0].keys():
        if keys != "audio":
            result[keys] = batch[0][keys]
    return result


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    dataset = load_dataset(
        "hlt-lab/voicebench",
        args.subset_name,
        split=args.split_name,
        trust_remote_code=True,
    )

    model, tokenizer = get_model(args)
    # tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=custom_collate,
    )

    total_steps = len(dataset)

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    message = [
        {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
        {"role": "assistant", "content": ""},
    ]
    input_ids, attention_mask = preprocess([message], tokenizer)
    results_jsonl_file = open(
        os.path.join(
            args.output_dir,
            f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
        ),
        "w",
    )
    for batch in dataloader:
        audio = batch["audio"]
        audio = torch.from_numpy(audio).to(device).to(torch.float32)
        fbank = whisper.log_mel_spectrogram(audio, device=device)
        fbank = fbank.unsqueeze(0)
        generated_ids = model.decode(
            fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
        )
        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        result_dict = {}
        for key in batch.keys():
            if key != "audio":
                result_dict[key] = batch[key]
        result_dict["response"] = hyps[0]
        json.dump(result_dict, results_jsonl_file)
        results_jsonl_file.write("\n")

        if rank == 0:
            progress_bar.update(world_size * args.batch_size)

    if rank == 0:
        progress_bar.close()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
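Each rank writes its own `results-{subset}-{split}-{rank}-audio.jsonl` under `--output-dir`. A small sketch of the assumed post-processing (not part of this commit) that concatenates the per-rank files into one JSONL for downstream scoring:

# merge_rank_results.py -- illustrative sketch only, not part of this commit
import glob
import os
import sys

# e.g. python3 merge_rank_results.py exp/log_voicebench openbookqa test
output_dir, subset, split = sys.argv[1:4]
pattern = os.path.join(output_dir, f"results-{subset}-{split}-*-audio.jsonl")
merged = os.path.join(output_dir, f"results-{subset}-{split}-merged.jsonl")
with open(merged, "w", encoding="utf-8") as out:
    for path in sorted(glob.glob(pattern)):
        with open(path, encoding="utf-8") as f:
            out.write(f.read())
print(f"wrote {merged}")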
egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# server.py
import argparse
import os
from typing import List

import torch
import uvicorn
import whisper
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model


def get_args():
    parser = argparse.ArgumentParser(description="speech-to-text decoding server")
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args


class SpeechRequest(BaseModel):
    audio: List[float]  # Expecting audio as a list of floats (raw waveform)
    sampling_rate: int = 16000


class TextResponse(BaseModel):
    text: str


def preprocess_prompt(tokenizer):
    """Preprocesses the prompt template."""
    texts = [
        tokenizer.apply_chat_template(
            message,  # Using the hardcoded message
            tokenize=True,
            add_generation_prompt=False,  # Important for generation
            chat_template=TEMPLATE,
            padding=False,  # No padding needed for single prompt
            truncation=False,
        )
    ]
    input_ids = torch.tensor(texts, dtype=torch.long)
    attention_mask = torch.ones_like(
        input_ids, dtype=torch.bool
    )  # Mask is all True for the prompt
    return input_ids, attention_mask


args = get_args()
model, tokenizer = get_model(args)
app = FastAPI()

device = torch.device("cuda")
message = [
    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
    {"role": "assistant", "content": ""},
]
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
prompt_input_ids = prompt_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)


@app.post("/decode", response_model=TextResponse)
async def decode_speech(request: SpeechRequest):
    """
    Receives audio waveform, processes it, and returns the decoded text.
    """
    if request.sampling_rate != 16000:
        raise HTTPException(
            status_code=400, detail="Only 16kHz sampling rate is supported."
        )

    try:
        audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
        fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
        fbank = fbank.unsqueeze(0)

        with torch.no_grad():
            generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)

        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        response_text = hyps[0] if hyps else ""

        return TextResponse(text=response_text)

    except Exception as e:
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")


if __name__ == "__main__":
    print("Starting server...")
    uvicorn.run(app, host="0.0.0.0", port=8000)
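Besides client.py above, the `/decode` endpoint accepts any 16 kHz mono waveform sent as a list of floats. A minimal request sketch (assuming the `soundfile` package and a hypothetical `question.wav`, neither required by this commit):

# post_wav.py -- illustrative sketch only, not part of this commit
import requests
import soundfile as sf

audio, sr = sf.read("question.wav", dtype="float32")  # hypothetical input file
assert sr == 16000, "server only accepts 16 kHz audio"
resp = requests.post(
    "http://localhost:8000/decode",
    json={"audio": audio.tolist(), "sampling_rate": sr},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["text"])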
@@ -74,7 +74,8 @@ def get_model(params, device="cuda"):
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    # codec_vocab_size = 4096 + 4
    codec_vocab_size = 6561 + 4
    config = Qwen2Config(
        vocab_size=codec_vocab_size,
        hidden_size=1024,