From 89781b9bb185f307bb692ed5ff7629d28a9248a3 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 12 May 2025 10:06:59 +0000
Subject: [PATCH] Add CosyVoice2 decoding

---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh    |  19 ++
 .../SPEECH2SPEECH/qwen_omni/data_module.py |   9 +-
 .../SPEECH2SPEECH/qwen_omni/decode.py      | 177 +++++++++++++-----
 3 files changed, 155 insertions(+), 50 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index fcdfdd69f..6d8f54135 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -192,3 +192,22 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
     --use-flash-attn True --on-the-fly-feats True \
     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 fi
+
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "stage 11: Decoding EN; only batch_size=1 is supported for now."
+  exp_dir=./qwen_omni/exp_speech2speech_en_continue
+  # cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+  python3 ./qwen_omni/decode.py \
+    --max-duration 1 \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --epoch 997 --avg 1 \
+    --manifest-dir data/fbank \
+    --use-flash-attn True \
+    --method e2e-epoch4_speech2speech \
+    --enable-speech-output True \
+    --token2wav-path /workspace/CosyVoice2-0.5B \
+    --use-lora True
+fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
index b0b039416..b02c9f4bf 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
@@ -47,9 +47,9 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
 from lhotse.utils import fix_random_seed
 from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-
 from utils import str2bool
 
+
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -457,9 +457,10 @@ class AsrDataModule:
     def test_cuts_en_vocalnet(self) -> CutSet:
         logging.info("About to get test cuts")
         VoiceAssistant_cuts = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
+            self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
         )
-        return VoiceAssistant_cuts
+        return {"test": VoiceAssistant_cuts}
+
     # def train_cuts_en_vocalnet(self) -> CutSet:
     #     logging.info("About to get train cuts")
     #     VoiceAssistant_cuts = load_manifest_lazy(
@@ -481,4 +482,4 @@ class AsrDataModule:
     #     VoiceAssistant_cuts = load_manifest_lazy(
     #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
     #     )
-    #     return VoiceAssistant_cuts
\ No newline at end of file
+    #     return VoiceAssistant_cuts
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
index e4dccf081..793b32112 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
@@ -55,7 +55,8 @@ import torch
 import torch.nn as nn
 import transformers
 import whisper
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
@@ -75,6 +76,57 @@ from icefall.utils import (
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") +def audio_decode_cosyvoice2( + audio_tokens, prompt_text, prompt_speech_path, codec_decoder +): + """ + Generate audio from tokens with optional tone and prompt embedding. + + Args: + audio_tokens (list): List of audio tokens to be processed. + model_config: Configuration object containing vocab settings. + codec_decoder: Codec decoder for generating audio. + tone_dir (str): The tone directory or setting. + audio_prompt_path (str, optional): Path to the audio prompt file. Required when tone_dir is not "default_tone". + code_layer (int, optional): Number of code layers. Defaults to 1. + num_latency_tokens (int, optional): Number of latency tokens to ignore. Defaults to 0. + speed (float, optional): Speed factor for audio generation. Defaults to 1.0. + + Returns: + torch.Tensor: Generated audio waveform. + """ + prompt_speech_16k = load_wav(prompt_speech_path, 16000) + model_inputs_dict = codec_decoder.frontend.frontend_zero_shot( + "empty", prompt_text, prompt_speech_16k, 24000 + ) + tts_mel, _ = codec_decoder.model.flow.inference( + token=audio_tokens.to(codec_decoder.model.device), + token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( + codec_decoder.model.device + ), + prompt_token=model_inputs_dict["flow_prompt_speech_token"].to( + codec_decoder.model.device + ), + prompt_token_len=torch.tensor( + [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32 + ).to(codec_decoder.model.device), + prompt_feat=model_inputs_dict["prompt_speech_feat"].to( + codec_decoder.model.device + ), + prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to( + codec_decoder.model.device + ), + embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device), + finalize=True, + ) + + audio_hat, _ = codec_decoder.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + return audio_hat + + def audio_decode_cosyvoice(audio_tokens, codec_decoder): """ Generate audio from tokens with optional tone and prompt embedding. 
@@ -180,7 +232,9 @@ def get_model(params, device):
         attn_implementation = "eager"
         torch_dtype = torch.float16
 
-    codec_vocab_size = 4096 + 4
+    # TODO: infer the codec vocabulary size from the token2wav model instead of
+    # hard-coding it (CosyVoice uses 4096 speech tokens, CosyVoice2 uses 6561).
+    # codec_vocab_size = 4096 + 4
+    codec_vocab_size = 6561 + 4
     config = Qwen2Config(
         vocab_size=codec_vocab_size,
         hidden_size=1024,
@@ -346,6 +400,20 @@ def get_parser():
         help="The path to the token2wav model",
     )
 
+    parser.add_argument(
+        "--prompt_text",
+        type=str,
+        default="Romeo and Juliet might be the most famous act of William Shakespeare.",
+        help="Transcript of the prompt speech used for CosyVoice2 zero-shot voice cloning",
+    )
+
+    parser.add_argument(
+        "--prompt_speech_path",
+        type=str,
+        default="./assets/common_voice_en_2586258.wav",
+        help="Path to the prompt speech audio used for CosyVoice2 zero-shot voice cloning",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -437,36 +505,42 @@
         2,
     )
 
-    chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
+    # chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
 
-    questions_with_history = [
-        cut.custom["question"] for cut in batch["supervisions"]["cut"]
-    ]
-    history_contexts = [
-        question.rsplit(":", 1)[0].strip() for question in questions_with_history
-    ]
-    last_questions = [
-        question.split(": ")[-1].strip() for question in questions_with_history
-    ]
+    # questions_with_history = [
+    #     cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    # ]
+    # history_contexts = [
+    #     question.rsplit(":", 1)[0].strip() for question in questions_with_history
+    # ]
+    # last_questions = [
+    #     question.split(": ")[-1].strip() for question in questions_with_history
+    # ]
+    # messages = []
+    # for i, total_round in enumerate(chat_rounds):
+    #     message = []
+    #     if total_round > 1:
+    #         history_question_answer = history_contexts[i].split("USER:")
+    #         history_question_answer = [item for item in history_question_answer if item]
+    #         for j in range(total_round - 1):
+    #             question_answer = history_question_answer[j].split("ASSISTANT:")
+    #             message += [
+    #                 {"role": "user", "content": question_answer[0].strip()},
+    #                 {"role": "assistant", "content": question_answer[1].strip()},
+    #             ]
+    #     message += [
+    #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+    #         {"role": "assistant", "content": ""},
+    #     ]
+    #     print(f"message: {message}, batch_size {len(chat_rounds)}")
+    #     messages.append(message)
     messages = []
-    for i, total_round in enumerate(chat_rounds):
-        message = []
-        if total_round > 1:
-            history_question_answer = history_contexts[i].split("USER:")
-            history_question_answer = [item for item in history_question_answer if item]
-            for j in range(total_round - 1):
-                question_answer = history_question_answer[j].split("ASSISTANT:")
-                message += [
-                    {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()},
-                ]
-        message += [
+    for i in range(len(batch["supervisions"]["cut"])):
+        message = [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
             {"role": "assistant", "content": ""},
         ]
-        print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
-
     input_ids, attention_mask = preprocess(messages, tokenizer)
     if params.enable_speech_output:
         generated_ids, generated_speech_output = model.decode_with_speech_output(
@@ -478,10 +552,19 @@
         ]  # WAR: only support batch = 1 for now
         for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
             speech_file_name = params.log_dir / f"{cut_id}.wav"
-            audio_tokens = [token for token in audio_tokens if token < 4096]
+            # audio_tokens = [token for token in audio_tokens if token < 4096]
             audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
-            audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
-            sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
+            if "CosyVoice2" in params.token2wav_path:
+                audio_hat = audio_decode_cosyvoice2(
+                    audio_tokens,
+                    params.prompt_text,
+                    params.prompt_speech_path,
+                    token2wav_model,
+                )
+                sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
+            else:
+                audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+                sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
     else:
         generated_ids = model.decode(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
@@ -521,18 +604,14 @@
     results = defaultdict(list)
 
     for batch_idx, batch in enumerate(dl):
-        answers = batch["supervisions"]["text"]
-        questions_with_history = [
-            cut.custom["question"] for cut in batch["supervisions"]["cut"]
-        ]
-        answer_cosyvoice_speech_token = [
-            cut.custom["answer_cosyvoice_speech_token"]
-            for cut in batch["supervisions"]["cut"]
-        ]
-        texts = [
-            question.split(": ")[-1].strip()
-            for question in questions_with_history
-        ]
+        texts = batch["supervisions"]["text"]
+        # questions_with_history = [
+        #     cut.custom["question"] for cut in batch["supervisions"]["cut"]
+        # ]
+        # texts = [
+        #     question.split(": ")[-1].strip()
+        #     for question in questions_with_history
+        # ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
@@ -636,9 +715,14 @@
     logging.info(f"device: {device}")
 
     model, tokenizer = get_model(params, device)
-    token2wav_model = CosyVoice(
-        params.token2wav_path, load_jit=False, load_trt=False, fp16=False
-    )
+    if "CosyVoice2" in params.token2wav_path:
+        token2wav_model = CosyVoice2(
+            params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+        )
+    else:
+        token2wav_model = CosyVoice(
+            params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+        )
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -656,8 +740,9 @@
             return False
         return True
 
-    test_sets_cuts = data_module.test_cuts()
-
+    # TODO: make the test set configurable instead of hard-coding the VocalNet cuts.
+    # test_sets_cuts = data_module.test_cuts()
+    test_sets_cuts = data_module.test_cuts_en_vocalnet()
     test_sets = test_sets_cuts.keys()
    test_dls = [
        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
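
Usage sketch (not part of the diff above): the snippet below shows how the CosyVoice2 token-to-wav path added by this patch can be exercised on its own. It assumes the CosyVoice checkout and its Matcha-TTS third_party path from prepare.sh, a local CosyVoice2-0.5B checkpoint, and that qwen_omni/decode.py is importable; the random tokens are placeholders for the output of model.decode_with_speech_output(), and the prompt wav/transcript pair is the default from get_parser().

import sys

import soundfile as sf
import torch

sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
from cosyvoice.cli.cosyvoice import CosyVoice2

from decode import audio_decode_cosyvoice2  # qwen_omni/decode.py from this patch

# Load the CosyVoice2 checkpoint that serves as the token2wav model
# (same path as --token2wav-path in prepare.sh stage 11).
token2wav_model = CosyVoice2(
    "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)

# Placeholder speech tokens: in decode.py these come from
# model.decode_with_speech_output() and index CosyVoice2's 6561-entry codebook.
audio_tokens = torch.randint(0, 6561, (1, 200), dtype=torch.int32)

# Zero-shot voice cloning: prompt_text must be the transcript of the prompt wav.
audio_hat = audio_decode_cosyvoice2(
    audio_tokens,
    "Romeo and Juliet might be the most famous act of William Shakespeare.",
    "./assets/common_voice_en_2586258.wav",
    token2wav_model,
)

# CosyVoice2 synthesizes at 24 kHz (CosyVoice at 22.05 kHz).
sf.write("example.wav", audio_hat.squeeze(0).cpu().numpy(), 24000)

This is the same branch that decode_one_batch() takes when "CosyVoice2" appears in --token2wav-path; otherwise it falls back to audio_decode_cosyvoice() and writes 22.05 kHz audio.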