From 71a0a442a6366edbdbc567448255aec63ef688cd Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Apr 2025 10:05:07 +0000 Subject: [PATCH] add history cache --- egs/speech_llm/SPEECH2SPEECH/prepare.sh | 26 ++++ .../SPEECH2SPEECH/slam_omni/model.py | 2 +- .../SPEECH2SPEECH/slam_omni/web_demo.py | 119 +++++++++++++----- 3 files changed, 116 insertions(+), 31 deletions(-) diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh index e0a2fa507..1b49daa65 100644 --- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh +++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh @@ -107,4 +107,30 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \ # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \ +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "stage 6: " + export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice + exp_dir=./slam_omni/exp_speech2speech_rerun + python3 ./slam_omni/web_demo.py \ + --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ + --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ + --checkpoint-path $exp_dir/epoch-998.pt \ + --use-flash-attn True \ + --enable-speech-output True \ + --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \ + --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share + +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "stage 7: " + model_path=local/sherpa-onnx-paraformer-zh-2023-09-14 + + if [ ! -d $model_path ]; then + pip install sherpa-onnx + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local + fi fi \ No newline at end of file diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py index c5f31226d..0cc93c237 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py @@ -630,7 +630,7 @@ class SPEECH_LLM(nn.Module): next_token_ids = topk_sampling( next_token_logits, ) - print(next_token_ids, "next_token_ids", t, next_token_ids.shape) + # print(next_token_ids, "next_token_ids", t, next_token_ids.shape) if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id: break # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py index ebcbc36ed..3155174fb 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py @@ -6,7 +6,7 @@ import gradio as gr import soundfile as sf import gradio.processing_utils as processing_utils - +import tempfile from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config from gradio_client import utils as client_utils @@ -17,8 +17,11 @@ from peft import LoraConfig, get_peft_model from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from model import SPEECH_LLM, EncoderProjector from train import DEFAULT_SPEECH_TOKEN, add_model_arguments - +import sherpa_onnx from cosyvoice.cli.cosyvoice import CosyVoice +import sys +sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS') + def get_model(params, device="cuda"): """Load and prepare the speech-to-speech model.""" @@ -177,26 +180,26 @@ def 
preprocess( return input_ids, attention_mask -def _launch_demo(args, model, tokenizer, token2wav_model): +def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): def format_history(history: list): messages = [] for item in history: if isinstance(item["content"], str): messages.append({"role": item['role'], "content": item['content']}) - elif item["role"] == "user" and (isinstance(item["content"], list) or - isinstance(item["content"], tuple)): - file_path = item["content"][0] - # TODO: check if the file_path's transcript is already in the history - mime_type = client_utils.get_mimetype(file_path) - if mime_type.startswith("audio"): - messages.append({ - "role": - item['role'], - "content": item["content"][1] # append audio transcript here - }) + # elif item["role"] == "user" and (isinstance(item["content"], list) or + # isinstance(item["content"], tuple)): + # file_path = item["content"][0] + # # TODO: check if the file_path's transcript is already in the history + # mime_type = client_utils.get_mimetype(file_path) + # if mime_type.startswith("audio"): + # messages.append({ + # "role": + # item['role'], + # "content": item["content"][1] # append audio transcript here + # }) print('predict history: ', messages) - messages = messages[-2:] # TODO: WAR: add history later + # messages = messages[-2:] # TODO: WAR: add history later return messages def decode( @@ -214,8 +217,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model): dtype = torch.float32 device = model.llm.device - feature = feature.to(device, dtype=dtype).transpose(1, 2) - assert feature.shape[2] == 80 + feature = feature.to(device, dtype=dtype)#.transpose(1, 2) + # assert feature.shape[2] == 80 input_ids, attention_mask = preprocess([messages], tokenizer) @@ -224,7 +227,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model): ) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - yield {"type": "text", "data": hyps} + # print('hyps: ', hyps, 23333333333333333333333333) + yield {"type": "text", "data": hyps[0]} + # yield {"type": "text", "data": hyps} audio_tokens = [token for token in audio_tokens if token < 4096] audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0) @@ -232,6 +237,10 @@ def _launch_demo(args, model, tokenizer, token2wav_model): audio = audio_hat.squeeze(0).cpu().numpy() # sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050) audio = np.array(audio * 32767).astype(np.int16) + # yield {"type": "audio", "data": (22050, audio)} + # with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: + # sf.write(tmpfile.name, audio, 22050, format="WAV") + # audio_path = tmpfile.name wav_io = io.BytesIO() sf.write(wav_io, audio, samplerate=22050, format="WAV") wav_io.seek(0) @@ -249,18 +258,22 @@ def _launch_demo(args, model, tokenizer, token2wav_model): gr.update(visible=False), # submit_btn gr.update(visible=True), # stop_btn ) - - assert audio is not None - # get audio transcript here + print(2333, history, audio) + history.append({"role": "user", "content": (audio,)}) history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}) history.append({"role": "assistant", "content": ""}) - formatted_history = format_history(history=history) + formatted_history = format_history(history=history) # only keep string text format - # audio_transcript = get_audio_transcript(audio) - audio_transcript = "audio transcript" - history[-2]["content"] = (audio, audio_transcript) - fbank = whisper.log_mel_spectrogram(audio, 
model.llm.device) - print('fbank: ', fbank.shape) + assert audio is not None + audio_transcript = get_transcript( + audio, + asr_model, + ) + print('audio_transcript: ', audio_transcript) + history[-2]["content"] = audio_transcript + + fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device) + fbank = fbank.unsqueeze(0) assert fbank.ndim == 3 # history.append({"role": "assistant", "content": ""}) @@ -342,6 +355,10 @@ def _get_args(): type=str, default=None, help='Token2Wav path, default to %(default)r') + parser.add_argument('--asr-model-dir', + type=str, + default=None, + help='ASR model dir, default to %(default)r') parser.add_argument('--flash-attn2', action='store_true', default=False, @@ -354,14 +371,56 @@ def _get_args(): action='store_true', default=False, help='Automatically launch the interface in a new tab on the default browser.') - parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.') + parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.') parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.') add_model_arguments(parser) args = parser.parse_args() return args + +def read_wave(wave_filename: str): + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and can be of type + 32-bit floating point PCM. Its sample rate does not need to be 24kHz. + + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, + which are normalized to the range [-1, 1]. + - Sample rate of the wave file. + """ + + samples, sample_rate = sf.read(wave_filename, dtype="float32") + assert ( + samples.ndim == 1 + ), f"Expected single channel, but got {samples.ndim} channels." + + samples_float32 = samples.astype(np.float32) + + return samples_float32, sample_rate + +def get_transcript(audio_path, recognizer): + samples, sample_rate = read_wave(audio_path) + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + recognizer.decode_streams([s]) + return s.result.text + if __name__ == "__main__": args = _get_args() model, tokenizer = get_model(args) - cosyvoice = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False) - _launch_demo(args, model, tokenizer, cosyvoice) \ No newline at end of file + token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False) + + asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=f"{args.asr_model_dir}/model.int8.onnx", + tokens=f"{args.asr_model_dir}/tokens.txt", + num_threads=2, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + debug=False, + ) + + _launch_demo(args, model, tokenizer, token2wav, asr_model) \ No newline at end of file
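
Usage note for the new prepare.sh stages: stage 6 launches the gradio speech-to-speech demo and points --asr-model-dir at local/sherpa-onnx-paraformer-zh-2023-09-14, which is only downloaded in stage 7, so stage 7 needs to have been run (or the model fetched by hand) before the demo will start. The sketch below is a minimal standalone version of the transcription path that web_demo.py wraps in read_wave()/get_transcript(); the recognizer arguments and model file names come from the patch, while test.wav is a placeholder input path.

    import sherpa_onnx
    import soundfile as sf

    # Same offline paraformer recognizer that web_demo.py builds at startup.
    asr_model_dir = "local/sherpa-onnx-paraformer-zh-2023-09-14"
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=f"{asr_model_dir}/model.int8.onnx",
        tokens=f"{asr_model_dir}/tokens.txt",
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )

    # Load a single-channel recording as float32 in [-1, 1], as read_wave() does.
    samples, sample_rate = sf.read("test.wav", dtype="float32")  # placeholder path
    assert samples.ndim == 1, "expected a mono (single-channel) wav"

    # Transcribe it the same way get_transcript() does.
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_streams([stream])
    print(stream.result.text)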
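
The web_demo.py changes are what the subject calls the history cache: format_history() no longer truncates to the last exchange (the messages = messages[-2:] line is commented out) and only forwards turns whose content is a plain string, while predict() builds the prompt while the newest user turn still holds the DEFAULT_SPEECH_TOKEN placeholder and only afterwards rewrites that turn to its ASR transcript, so later turns see the whole conversation as text. A rough illustration of that flow, with a hypothetical token value and made-up turn contents:

    # Hypothetical stand-ins; the real token comes from train.DEFAULT_SPEECH_TOKEN
    # and the real transcript from the paraformer recognizer.
    DEFAULT_SPEECH_TOKEN = "<speech>"

    def keep_text_turns(history):
        # Mirrors format_history(): only string contents reach the LLM prompt.
        return [m for m in history if isinstance(m["content"], str)]

    history = [
        {"role": "user", "content": "question from the previous turn"},
        {"role": "assistant", "content": "answer from the previous turn"},
    ]

    # New turn: the prompt is assembled while the user turn is still the placeholder...
    history.append({"role": "user", "content": DEFAULT_SPEECH_TOKEN})
    history.append({"role": "assistant", "content": ""})
    prompt_messages = keep_text_turns(history)

    # ...and afterwards the placeholder is replaced by the ASR transcript, so the
    # cached history that future turns see is plain text.
    history[-2]["content"] = "transcript of the new recording"

    print(prompt_messages)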
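
On the input-feature side, decode() now keeps the (batch, 80, frames) layout: the old transpose(1, 2) and the shape[2] == 80 assertion are commented out, and predict() calls whisper.log_mel_spectrogram with device= and adds the batch dimension itself. A minimal sketch of that feature preparation, assuming a CUDA device and a placeholder wav path:

    import whisper

    # log_mel_spectrogram accepts a file path (or a 1-D waveform) and returns an
    # (80, frames) tensor; unsqueeze(0) adds the batch dimension decode() expects.
    fbank = whisper.log_mel_spectrogram("test.wav", device="cuda")  # placeholder path
    fbank = fbank.unsqueeze(0)
    assert fbank.ndim == 3 and fbank.shape[1] == 80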
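
On the output side, decode() yields the decoded text first, drops speech tokens >= 4096, runs CosyVoice's token2wav to get a 22.05 kHz waveform, and streams it back as 16-bit PCM wrapped in an in-memory WAV. A small sketch of that last conversion with a hypothetical helper name; the explicit clip is added here (the patch multiplies by 32767 without clipping):

    import io

    import numpy as np
    import soundfile as sf

    def to_wav_bytes(audio: np.ndarray, sample_rate: int = 22050) -> io.BytesIO:
        # Convert the float waveform to 16-bit PCM; clipping guards against the
        # occasional sample outside [-1, 1] overflowing int16.
        pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        buf = io.BytesIO()
        sf.write(buf, pcm16, samplerate=sample_rate, format="WAV")
        buf.seek(0)
        return buf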