From 47920c233698d001eb1ca6f7436d1d792ea358c7 Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Fri, 25 Apr 2025 16:05:37 +0800
Subject: [PATCH] add gradio demo

---
 .../SPEECH2SPEECH/slam_omni/train.py    |   6 +
 .../SPEECH2SPEECH/slam_omni/web_demo.py | 267 +++++++++++++++---
 2 files changed, 226 insertions(+), 47 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index f5356dc43..3b971dd89 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -99,6 +99,12 @@ def set_batch_count(model: nn.Module, batch_count: float) -> None:
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--remove-whisper-encoder-input-length-restriction",
+        type=str2bool,
+        default=True,
+        help="Replace the Whisper encoder's forward method to remove the fixed 30-second input length restriction.",
+    )
     parser.add_argument(
         "--llm-path-or-name",
         type=str,
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py
index f856bf26f..ebcbc36ed 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py
@@ -7,26 +7,177 @@
 import soundfile as sf
 import gradio.processing_utils as processing_utils
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from gradio_client import utils as client_utils
 from argparse import ArgumentParser
+import whisper
+import torch
+from peft import LoraConfig, get_peft_model
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from model import SPEECH_LLM, EncoderProjector
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 
-def _load_model_processor(args):
+from cosyvoice.cli.cosyvoice import CosyVoice
 
-    # Check if flash-attn2 flag is enabled and load model accordingly
-    if args.flash_attn2:
-        # model = Qwen2_5OmniForConditionalGeneration.from_pretrained(args.checkpoint_path,
-        #                                                             torch_dtype='auto',
-        #                                                             attn_implementation='flash_attention_2',
-        #                                                             device_map=device_map)
-    # else:
-        # model = Qwen2_5OmniForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map=device_map, torch_dtype='auto')
+def get_model(params, device="cuda"):
+    """Load and prepare the speech-to-speech model."""
+    if params.remove_whisper_encoder_input_length_restriction:
+        replace_whisper_encoder_forward()
 
-    # processor = Qwen2_5OmniProcessor.from_pretrained(args.checkpoint_path)
-    return model, processor
+    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+    speech_encoder = whisper_model.encoder
+    speech_encoder_dim = whisper_model.dims.n_audio_state
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
 
-def _launch_demo(args, model, processor):
+    if params.use_flash_attn:
+        attn_implementation = "flash_attention_2"
+    else:
+        attn_implementation = "eager"
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        params.llm_path_or_name,
+        attn_implementation=attn_implementation,
+        torch_dtype=torch.float16,
+    )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "up_proj",
+                "gate_proj",
+                "down_proj",
+            ],
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()
+
+    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+    tokenizer.add_special_tokens(special_tokens_dict)
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+        DEFAULT_SPEECH_TOKEN
+    )
+
+    encoder_projector = EncoderProjector(
+        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+    )
+
+    # 4096 CosyVoice audio tokens plus 4 specials (mask/bos/eos/pad) at the top.
+    codec_vocab_size = 4096 + 4
+    config = Qwen2Config(
+        vocab_size=codec_vocab_size,
+        hidden_size=1024,
+        num_hidden_layers=12,
+        num_attention_heads=16,
+        num_key_value_heads=16,
+        intermediate_size=2048,
+        max_position_embeddings=4096,
+    )
+    codec_lm = AutoModelForCausalLM.from_config(
+        config=config,
+        attn_implementation=attn_implementation,
+        torch_dtype=torch.float16,
+    )
+    codec_lm.resize_token_embeddings(codec_vocab_size)
+    codec_lm.vocab_size = codec_vocab_size
+    codec_lm.config.pad_token_id = codec_vocab_size - 1
+    codec_lm.config.eos_token_id = codec_vocab_size - 2
+    codec_lm.config.bos_token_id = codec_vocab_size - 3
+    codec_lm.config.mask_token_id = codec_vocab_size - 4
+
+    model = SPEECH_LLM(
+        speech_encoder,
+        llm,
+        encoder_projector,
+        codec_lm,
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
+    )
+
+    checkpoint = torch.load(params.checkpoint_path, map_location="cpu")
+    model.load_state_dict(checkpoint, strict=False)
+
+    model.to(device)
+    model.eval()
+    return model, tokenizer
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+    """
+    Generate a waveform from discrete audio tokens with CosyVoice.
+
+    Runs the flow-matching decoder to obtain a mel-spectrogram and the HiFT
+    vocoder to obtain audio, using the built-in '中文女' speaker embedding
+    and an empty prompt.
+
+    Args:
+        audio_tokens (torch.Tensor): Audio tokens of shape (1, num_tokens).
+        codec_decoder: CosyVoice codec decoder used to generate the audio.
+
+    Returns:
+        torch.Tensor: Generated audio waveform.
+    """
+    flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding']
+    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+    prompt_speech_feat = torch.zeros(1, 0, 80)
+    device = codec_decoder.model.device
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(device),
+        prompt_token=flow_prompt_speech_token.to(device),
+        prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(device),
+        prompt_feat=prompt_speech_feat.to(device),
+        prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(device),
+        embedding=flow_embedding.to(device),
+        flow_cache=torch.zeros(1, 80, 0, 2).to(device),
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+    return audio_hat
+
+
+def preprocess(messages, tokenizer):
+    """Tokenize chat messages with the Qwen chat template and pad them into a batch."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for msg in messages:
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                add_generation_prompt=False,
+                chat_template=TEMPLATE,
+                padding="longest",
+                truncation=False,
+            )
+        )
+    max_len_texts = max(len(text) for text in texts)
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+
+    input_ids = torch.tensor(texts, dtype=torch.int)
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+    return input_ids, attention_mask
+
+
+def _launch_demo(args, model, tokenizer, token2wav_model):
 
     def format_history(history: list):
         messages = []
@@ -36,42 +187,58 @@ def _launch_demo(args, model, processor):
             elif item["role"] == "user" and (isinstance(item["content"], list) or isinstance(item["content"], tuple)):
                 file_path = item["content"][0]
-
+                # TODO: check if the file_path's transcript is already in the history
                 mime_type = client_utils.get_mimetype(file_path)
                 if mime_type.startswith("audio"):
                     messages.append({
                         "role": item['role'],
-                        "content": [{
-                            "type": "audio",
-                            "audio": file_path,
-                        }]
+                        "content": item["content"][1]  # use the audio transcript as the message text
                     })
+        print('predict history: ', messages)
+        messages = messages[-2:]  # TODO: workaround, keep only the latest exchange until full history is supported
         return messages
 
-    def predict(messages):
-        print('predict history: ', messages)
+    def decode(
+        model,
+        token2wav_model,
+        tokenizer,
+        feature,
+        messages,
+    ):
+        """Decode a single utterance: yield the text reply, then the audio reply.
+
+        Yields:
+            dict: {"type": "text", "data": <response string>} first, then
+            {"type": "audio", "data": <path to the cached WAV file>}.
+        """
 
-        text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        dtype = torch.float32
+        device = model.llm.device
 
-        audios = [msg['content'][0]['audio'] for msg in messages if msg['role'] == 'user' and isinstance(msg['content'], list) and msg['content'][0]['type'] == 'audio']
+        feature = feature.to(device, dtype=dtype).transpose(1, 2)  # (1, num_frames, 80)
+        assert feature.shape[2] == 80
+
+        input_ids, attention_mask = preprocess([messages], tokenizer)
 
-        inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
-        inputs = inputs.to(model.device).to(model.dtype)
+        generated_ids, audio_tokens = model.decode_with_speech_output(
+            feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        )
 
-        text_ids, audio = model.generate(**inputs)
-
-        response = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        response = response[0].split("\n")[-1]
-        yield {"type": "text", "data": response}
+        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        yield {"type": "text", "data": hyps[0]}  # batch size is always 1 here
 
+        # Drop codec special tokens (ids >= 4096); only real audio tokens go to the vocoder.
+        audio_tokens = [token for token in audio_tokens if token < 4096]
+        audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+        audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+        audio = audio_hat.squeeze(0).cpu().numpy()
         audio = np.array(audio * 32767).astype(np.int16)
         wav_io = io.BytesIO()
-        sf.write(wav_io, audio, samplerate=24000, format="WAV")
+        sf.write(wav_io, audio, samplerate=22050, format="WAV")
         wav_io.seek(0)
         wav_bytes = wav_io.getvalue()
         audio_path = processing_utils.save_bytes_to_cache(
             wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
+
         yield {"type": "audio", "data": audio_path}
 
     def media_predict(audio, history):
@@ -83,18 +250,21 @@ def _launch_demo(args, model, processor):
             gr.update(visible=True),  # stop_btn
         )
 
-        files = [audio]
-
-        for f in files:
-            if f:
-                history.append({"role": "user", "content": (f, )})
-
+        assert audio is not None
+        history.append({"role": "user", "content": DEFAULT_SPEECH_TOKEN})
+        history.append({"role": "assistant", "content": ""})
         formatted_history = format_history(history=history)
+        # TODO: replace this placeholder with a real transcript, e.g.
+        # audio_transcript = get_audio_transcript(audio)
+        audio_transcript = "audio transcript"
+        history[-2]["content"] = (audio, audio_transcript)
+        fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
+        fbank = fbank.unsqueeze(0)  # add a batch dimension: (1, 80, num_frames)
+        print('fbank: ', fbank.shape)
+        assert fbank.ndim == 3
 
-        history.append({"role": "assistant", "content": ""})
-
-        for chunk in predict(formatted_history):
+        for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history):
             if chunk["type"] == "text":
                 history[-1]["content"] = chunk["data"]
                 yield (
@@ -123,10 +293,9 @@ def _launch_demo(args, model, processor):
 
         with gr.Column(scale=1):
             microphone = gr.Audio(sources=['microphone'], type="filepath")
-            submit_btn = gr.Button(get_text("Submit", "提交"),
-                                   variant="primary")
-            stop_btn = gr.Button(get_text("Stop", "停止"), visible=False)
-            clear_btn = gr.Button(get_text("Clear History", "清除历史"))
+            submit_btn = gr.Button("Submit", variant="primary")
+            stop_btn = gr.Button("Stop", visible=False)
+            clear_btn = gr.Button("Clear History")
 
         with gr.Column(scale=2):
             media_chatbot = gr.Chatbot(height=650, type="messages")
@@ -169,7 +338,10 @@ def _get_args():
                         type=str,
                         default=None,
                         help='Checkpoint name or path, default to %(default)r')
-
+    parser.add_argument('--token2wav-path',
+                        type=str,
+                        default=None,
+                        help='CosyVoice token2wav model path, default to %(default)r')
     parser.add_argument('--flash-attn2',
                         action='store_true',
                         default=False,
@@ -184,11 +356,12 @@ def _get_args():
                         help='Automatically launch the interface in a new tab on the default browser.')
     parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
-
+    add_model_arguments(parser)
     args = parser.parse_args()
     return args
 
 
 if __name__ == "__main__":
     args = _get_args()
-    model, processor = _load_model_processor(args)
-    _launch_demo(args, model, processor)
\ No newline at end of file
+    model, tokenizer = get_model(args)
+    cosyvoice = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
+    _launch_demo(args, model, tokenizer, cosyvoice)
\ No newline at end of file
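
A minimal, self-contained sketch of the codec-token bookkeeping this patch relies on: get_model() sizes the codec LM vocabulary at 4096 CosyVoice audio tokens plus four specials stacked on top, which is why decode() drops every id >= 4096 before vocoding. The generated ids below are made up for illustration.

    # Codec vocabulary layout as configured in get_model().
    codec_vocab_size = 4096 + 4
    special_ids = {
        "pad": codec_vocab_size - 1,   # 4099
        "eos": codec_vocab_size - 2,   # 4098
        "bos": codec_vocab_size - 3,   # 4097
        "mask": codec_vocab_size - 4,  # 4096
    }

    generated = [12, 805, 4097, 3301, 4098, 4099]  # hypothetical codec LM output
    # Same filter decode() applies before calling audio_decode_cosyvoice().
    audio_tokens = [t for t in generated if t < 4096]
    assert audio_tokens == [12, 805, 3301]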
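
To see what preprocess() feeds the LLM, here is a sketch that renders its chat template directly with jinja2 (already a transformers dependency). The '<speech>' string is only a stand-in for DEFAULT_SPEECH_TOKEN, whose real value is defined in train.py.

    from jinja2 import Template

    # Same template string as TEMPLATE in preprocess().
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"

    messages = [
        {"role": "user", "content": "<speech>"},  # stand-in for DEFAULT_SPEECH_TOKEN
        {"role": "assistant", "content": ""},     # empty turn left open for generation
    ]
    print(Template(TEMPLATE).render(messages=messages))
    # <|im_start|>user
    # <speech><|im_end|>
    # <|im_start|>assistant

Note that the template deliberately omits the final <|im_end|>, so the assistant turn stays open for generation.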
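
And a self-contained sketch (numpy and soundfile, both already imported by web_demo.py) of the float-to-PCM conversion decode() performs before caching the reply, using a sine wave as a stand-in for the CosyVoice vocoder output; 22050 Hz matches the samplerate decode() writes.

    import io

    import numpy as np
    import soundfile as sf

    sr = 22050  # CosyVoice output sample rate used in decode()
    t = np.linspace(0, 1.0, sr, endpoint=False)
    audio = 0.5 * np.sin(2 * np.pi * 440.0 * t)  # stand-in for audio_hat

    pcm16 = (audio * 32767).astype(np.int16)  # same scaling as decode()
    wav_io = io.BytesIO()
    sf.write(wav_io, pcm16, samplerate=sr, format="WAV")
    wav_io.seek(0)
    print(f"{len(wav_io.getvalue())} bytes of WAV data")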