# Pasted from web view: 2025-05-19 01:31:21 +00:00 — 125 lines, 4.0 KiB, Python
# server.py
import argparse
import os
from typing import List
import torch
import uvicorn
import whisper
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model
def get_args():
    """Parse command-line arguments for the speech decoding server.

    Returns:
        argparse.Namespace carrying ``checkpoint_path``, ``prompt_template``,
        plus whatever model-construction flags ``add_model_arguments``
        registers (shared with train.py).
    """
    # NOTE(review): the original description read "extract speech code",
    # which looks copy-pasted from another script; this file serves decoding.
    parser = argparse.ArgumentParser(description="speech decoding server")
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--prompt-template",
        type=str,
        default=None,
        help="Prompt template",
    )
    # Model flags are shared with training so checkpoints load consistently.
    add_model_arguments(parser)
    args = parser.parse_args()
    return args
class SpeechRequest(BaseModel):
    """Request body for POST /decode: a raw mono waveform and its sample rate."""

    audio: List[float]  # Expecting audio as a list of floats (raw waveform)
    sampling_rate: int = 16000  # the /decode endpoint rejects anything but 16 kHz
class TextResponse(BaseModel):
    """Response body for POST /decode: the text decoded from the audio."""

    text: str
def preprocess_prompt(tokenizer, messages=None, chat_template=None):
    """Tokenize the chat prompt into model-ready tensors.

    Args:
        tokenizer: object exposing ``apply_chat_template`` (e.g. a HF
            tokenizer) returning a flat list of token ids.
        messages: chat messages to render; defaults to the module-level
            ``message`` built from the selected prompt template.
        chat_template: chat template string; defaults to the module-level
            ``TEMPLATE``.

    Returns:
        Tuple ``(input_ids, attention_mask)``: a ``(1, seq_len)`` long tensor
        of token ids and a same-shaped all-True bool mask (a single prompt
        has no padding, so every position is attended to).
    """
    # Backward-compatible generalization: fall back to the module-level
    # prompt objects the original hard-coded.
    if messages is None:
        messages = message
    if chat_template is None:
        chat_template = TEMPLATE
    token_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False,
        chat_template=chat_template,
        padding=False,  # No padding needed for single prompt
        truncation=False,
    )
    # Wrap in a list to add the batch dimension: shape (1, seq_len).
    input_ids = torch.tensor([token_ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    return input_ids, attention_mask
# ---- Startup: load the model, pick a prompt, pre-tokenize it once. ----
args = get_args()
model, tokenizer = get_model(args)
app = FastAPI()
device = torch.device("cuda")

# Dispatch table mapping each supported --prompt-template value to its prompt.
_PROMPTS = {
    None: f"{DEFAULT_SPEECH_TOKEN}",
    "qa": f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}",
    "continuation": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
    "asr": (
        f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
    ),
    "mt": f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}",
}
if args.prompt_template not in _PROMPTS:
    raise ValueError(f"Invalid prompt template: {args.prompt_template}")
template = _PROMPTS[args.prompt_template]
print("Using template:", template)

# Single-turn chat: the user prompt plus an empty assistant turn for the
# model to complete.
message = [
    {"role": "user", "content": template},
    {"role": "assistant", "content": ""},
]
# ChatML-style (<|im_start|>/<|im_end|>) template; the final turn is left
# open so generation continues from the assistant role.
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"

# The prompt never changes at runtime, so tokenize it once and keep the
# tensors on the GPU for every request.
prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
prompt_input_ids = prompt_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)
@app.post("/decode", response_model=TextResponse)
async def decode_speech(request: SpeechRequest):
    """
    Receives audio waveform, processes it, and returns the decoded text.
    """
    # Guard clause: the mel-spectrogram frontend below assumes 16 kHz audio.
    if request.sampling_rate != 16000:
        raise HTTPException(
            status_code=400, detail="Only 16kHz sampling rate is supported."
        )
    try:
        # Raw waveform -> 80-bin log-mel features, with a leading batch dim.
        waveform = torch.tensor(request.audio, dtype=torch.float32).to(device)
        features = whisper.log_mel_spectrogram(waveform, device=device, n_mels=80)
        features = features.unsqueeze(0)
        # Inference only: no gradients needed for decoding.
        with torch.no_grad():
            token_ids = model.decode(features, prompt_input_ids, prompt_attention_mask)
        decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return TextResponse(text=decoded[0] if decoded else "")
    except Exception as e:
        # Service boundary: log the failure and surface it as a 500.
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
if __name__ == "__main__":
    # Serve the FastAPI app with uvicorn on all interfaces, port 8000.
    print("Starting server...")
    uvicorn.run(app, host="0.0.0.0", port=8000)