Add CosyVoice2 decoding

root 2025-05-12 10:06:59 +00:00
parent b20a0d0e35
commit 89781b9bb1
3 changed files with 155 additions and 50 deletions


@@ -192,3 +192,22 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
--use-flash-attn True --on-the-fly-feats True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "stage 11: Decoding EN, only support batch_size=1 for now."
exp_dir=./qwen_omni/exp_speech2speech_en_continue
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 997 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch4_speech2speech \
--enable-speech-output True \
--token2wav-path /workspace/CosyVoice2-0.5B \
--use-lora True
fi
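# Note: decode.py selects the token-to-wav decoder from --token2wav-path: a path
# containing "CosyVoice2" loads CosyVoice2 and writes 24 kHz wavs, while any other
# path falls back to the original CosyVoice at 22.05 kHz.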


@@ -47,9 +47,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
from utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
@@ -457,9 +457,10 @@ class AsrDataModule:
def test_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get test cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
)
return VoiceAssistant_cuts
return {"test": VoiceAssistant_cuts}
# def train_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get train cuts")
# VoiceAssistant_cuts = load_manifest_lazy(


@@ -55,7 +55,8 @@ import torch
import torch.nn as nn
import transformers
import whisper
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
@@ -75,6 +76,57 @@ from icefall.utils import (
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_path, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
Args:
audio_tokens (list): List of audio tokens to be processed.
model_config: Configuration object containing vocab settings.
codec_decoder: Codec decoder for generating audio.
tone_dir (str): The tone directory or setting.
audio_prompt_path (str, optional): Path to the audio prompt file. Required when tone_dir is not "default_tone".
code_layer (int, optional): Number of code layers. Defaults to 1.
num_latency_tokens (int, optional): Number of latency tokens to ignore. Defaults to 0.
speed (float, optional): Speed factor for audio generation. Defaults to 1.0.
Returns:
torch.Tensor: Generated audio waveform.
"""
prompt_speech_16k = load_wav(prompt_speech_path, 16000)
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
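# Usage sketch (assumptions: a CosyVoice2 checkpoint at /workspace/CosyVoice2-0.5B and
# the prompt wav shipped under ./assets; the random tokens below are placeholders, not
# real speech-LLM output):
#
#   codec_decoder = CosyVoice2(
#       "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
#   )
#   tokens = torch.randint(0, 6561, (1, 50), dtype=torch.int32)
#   wav = audio_decode_cosyvoice2(
#       tokens,
#       "Romeo and Juliet might be the most famous act of William Shakespeare.",
#       "./assets/common_voice_en_2586258.wav",
#       codec_decoder,
#   )
#   sf.write("sample.wav", wav.squeeze(0).cpu().numpy(), 24000)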
def audio_decode_cosyvoice(audio_tokens, codec_decoder):
"""
Generate audio from tokens with optional tone and prompt embedding.
@@ -180,7 +232,9 @@ def get_model(params, device):
attn_implementation = "eager"
torch_dtype = torch.float16
codec_vocab_size = 4096 + 4
# TODO: FIX ME: codec_vocab_size is hard-coded; CosyVoice uses a 4096-token
# speech codebook while CosyVoice2 uses 6561 (plus 4 special ids).
# codec_vocab_size = 4096 + 4
codec_vocab_size = 6561 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
@@ -346,6 +400,20 @@ def get_parser():
help="The path to the token2wav model",
)
parser.add_argument(
"--prompt_text",
type=str,
default="Romeo and Juliet might be the most famous act of William Shakespeare.",
help="The prompt text",
)
parser.add_argument(
"--prompt_speech_path",
type=str,
default="./assets/common_voice_en_2586258.wav",
help="The path to the prompt speech",
)
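# Note: the defaults above are illustrative; prompt_text is expected to be the
# transcript of prompt_speech_path, and both feed CosyVoice2's zero-shot frontend
# inside audio_decode_cosyvoice2() (the prompt wav is loaded there at 16 kHz).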
add_model_arguments(parser)
return parser
@@ -437,36 +505,42 @@ def decode_one_batch(
2,
)
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
# chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history
]
# questions_with_history = [
# cut.custom["question"] for cut in batch["supervisions"]["cut"]
# ]
# history_contexts = [
# question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
# ]
# last_questions = [
# question.split("<USER>: ")[-1].strip() for question in questions_with_history
# ]
# messages = []
# for i, total_round in enumerate(chat_rounds):
# message = []
# if total_round > 1:
# history_question_answer = history_contexts[i].split("USER:")
# history_question_answer = [item for item in history_question_answer if item]
# for j in range(total_round - 1):
# question_answer = history_question_answer[j].split("ASSISTANT:")
# message += [
# {"role": "user", "content": question_answer[0].strip()},
# {"role": "assistant", "content": question_answer[1].strip()},
# ]
# message += [
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": ""},
# ]
# print(f"message: {message}, batch_size {len(chat_rounds)}")
# messages.append(message)
messages = []
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
for i in range(len(batch["supervisions"]["cut"])):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]
print(f"message: {message}, batch_size {len(chat_rounds)}")
messages.append(message)
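# Each cut gets a single-turn message: one user turn holding only the speech
# placeholder token and an empty assistant turn for the model to fill in.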
input_ids, attention_mask = preprocess(messages, tokenizer)
if params.enable_speech_output:
generated_ids, generated_speech_output = model.decode_with_speech_output(
@@ -478,8 +552,17 @@ def decode_one_batch(
] # workaround: only batch_size=1 is supported for now
for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
speech_file_name = params.log_dir / f"{cut_id}.wav"
audio_tokens = [token for token in audio_tokens if token < 4096]
# audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
if "CosyVoice2" in params.token2wav_path:
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
params.prompt_text,
params.prompt_speech_path,
token2wav_model,
)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
else:
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
else:
@@ -521,18 +604,14 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
answers = batch["supervisions"]["text"]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
answer_cosyvoice_speech_token = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
texts = [
question.split("<USER>: ")[-1].strip()
for question in questions_with_history
]
texts = batch["supervisions"]["text"]
# questions_with_history = [
# cut.custom["question"] for cut in batch["supervisions"]["cut"]
# ]
# texts = [
# question.split("<USER>: ")[-1].strip()
# for question in questions_with_history
# ]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@@ -636,6 +715,11 @@ def main():
logging.info(f"device: {device}")
model, tokenizer = get_model(params, device)
if "CosyVoice2" in params.token2wav_path:
token2wav_model = CosyVoice2(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
else:
token2wav_model = CosyVoice(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
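# The decoder is chosen from the checkpoint directory name: a token2wav path
# containing "CosyVoice2" loads CosyVoice2 (24 kHz output), anything else falls
# back to the original CosyVoice (22.05 kHz).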
@@ -656,8 +740,9 @@ def main():
return False
return True
test_sets_cuts = data_module.test_cuts()
# TODO: FIX ME (hard-coded to the VocalNet EN test set for now)
# test_sets_cuts = data_module.test_cuts()
test_sets_cuts = data_module.test_cuts_en_vocalnet()
test_sets = test_sets_cuts.keys()
test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))