add cosyvoice2 decode

This commit is contained in:
root 2025-05-12 10:06:59 +00:00
parent b20a0d0e35
commit 89781b9bb1
3 changed files with 155 additions and 50 deletions

View File

@ -192,3 +192,22 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
--use-flash-attn True --on-the-fly-feats True \ --use-flash-attn True --on-the-fly-feats True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "stage 11: Decoding EN, only support batch_size=1 for now."
exp_dir=./qwen_omni/exp_speech2speech_en_continue
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 997 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch4_speech2speech \
--enable-speech-output True \
--token2wav-path /workspace/CosyVoice2-0.5B \
--use-lora True
fi

View File

@ -47,9 +47,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from utils import str2bool from utils import str2bool
class _SeedWorkers: class _SeedWorkers:
def __init__(self, seed: int): def __init__(self, seed: int):
self.seed = seed self.seed = seed
@ -457,9 +457,10 @@ class AsrDataModule:
def test_cuts_en_vocalnet(self) -> CutSet: def test_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get test cuts") logging.info("About to get test cuts")
VoiceAssistant_cuts = load_manifest_lazy( VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz" self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
) )
return VoiceAssistant_cuts return {"test": VoiceAssistant_cuts}
# def train_cuts_en_vocalnet(self) -> CutSet: # def train_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get train cuts") # logging.info("About to get train cuts")
# VoiceAssistant_cuts = load_manifest_lazy( # VoiceAssistant_cuts = load_manifest_lazy(
@ -481,4 +482,4 @@ class AsrDataModule:
# VoiceAssistant_cuts = load_manifest_lazy( # VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_debug.jsonl.gz" # self.args.manifest_dir / "cuts_debug.jsonl.gz"
# ) # )
# return VoiceAssistant_cuts # return VoiceAssistant_cuts

View File

@ -55,7 +55,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import transformers import transformers
import whisper import whisper
from cosyvoice.cli.cosyvoice import CosyVoice from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from data_module import AsrDataModule from data_module import AsrDataModule
from lhotse.cut import Cut from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector from model import SPEECH_LLM, EncoderProjector
@ -75,6 +76,57 @@ from icefall.utils import (
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_path, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
Args:
audio_tokens (list): List of audio tokens to be processed.
model_config: Configuration object containing vocab settings.
codec_decoder: Codec decoder for generating audio.
tone_dir (str): The tone directory or setting.
audio_prompt_path (str, optional): Path to the audio prompt file. Required when tone_dir is not "default_tone".
code_layer (int, optional): Number of code layers. Defaults to 1.
num_latency_tokens (int, optional): Number of latency tokens to ignore. Defaults to 0.
speed (float, optional): Speed factor for audio generation. Defaults to 1.0.
Returns:
torch.Tensor: Generated audio waveform.
"""
prompt_speech_16k = load_wav(prompt_speech_path, 16000)
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def audio_decode_cosyvoice(audio_tokens, codec_decoder): def audio_decode_cosyvoice(audio_tokens, codec_decoder):
""" """
Generate audio from tokens with optional tone and prompt embedding. Generate audio from tokens with optional tone and prompt embedding.
@ -180,7 +232,9 @@ def get_model(params, device):
attn_implementation = "eager" attn_implementation = "eager"
torch_dtype = torch.float16 torch_dtype = torch.float16
codec_vocab_size = 4096 + 4 # TODO: FIX ME
# codec_vocab_size = 4096 + 4
codec_vocab_size = 6561 + 4
config = Qwen2Config( config = Qwen2Config(
vocab_size=codec_vocab_size, vocab_size=codec_vocab_size,
hidden_size=1024, hidden_size=1024,
@ -346,6 +400,20 @@ def get_parser():
help="The path to the token2wav model", help="The path to the token2wav model",
) )
parser.add_argument(
"--prompt_text",
type=str,
default="Romeo and Juliet might be the most famous act of William Shakespeare.",
help="The prompt text",
)
parser.add_argument(
"--prompt_speech_path",
type=str,
default="./assets/common_voice_en_2586258.wav",
help="The path to the prompt speech",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -437,36 +505,42 @@ def decode_one_batch(
2, 2,
) )
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]] # chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
questions_with_history = [ # questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"] # cut.custom["question"] for cut in batch["supervisions"]["cut"]
] # ]
history_contexts = [ # history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history # question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
] # ]
last_questions = [ # last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history # question.split("<USER>: ")[-1].strip() for question in questions_with_history
] # ]
# messages = []
# for i, total_round in enumerate(chat_rounds):
# message = []
# if total_round > 1:
# history_question_answer = history_contexts[i].split("USER:")
# history_question_answer = [item for item in history_question_answer if item]
# for j in range(total_round - 1):
# question_answer = history_question_answer[j].split("ASSISTANT:")
# message += [
# {"role": "user", "content": question_answer[0].strip()},
# {"role": "assistant", "content": question_answer[1].strip()},
# ]
# message += [
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": ""},
# ]
# print(f"message: {message}, batch_size {len(chat_rounds)}")
# messages.append(message)
messages = [] messages = []
for i, total_round in enumerate(chat_rounds): for i in range(len(batch["supervisions"]["cut"])):
message = [] message = [
if total_round > 1:
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}, {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""}, {"role": "assistant", "content": ""},
] ]
print(f"message: {message}, batch_size {len(chat_rounds)}")
messages.append(message) messages.append(message)
input_ids, attention_mask = preprocess(messages, tokenizer) input_ids, attention_mask = preprocess(messages, tokenizer)
if params.enable_speech_output: if params.enable_speech_output:
generated_ids, generated_speech_output = model.decode_with_speech_output( generated_ids, generated_speech_output = model.decode_with_speech_output(
@ -478,10 +552,19 @@ def decode_one_batch(
] # WAR: only support batch = 1 for now ] # WAR: only support batch = 1 for now
for cut_id, audio_tokens in zip(cut_ids, generated_speech_output): for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
speech_file_name = params.log_dir / f"{cut_id}.wav" speech_file_name = params.log_dir / f"{cut_id}.wav"
audio_tokens = [token for token in audio_tokens if token < 4096] # audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0) audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model) if "CosyVoice2" in params.token2wav_path:
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050) audio_hat = audio_decode_cosyvoice2(
audio_tokens,
params.prompt_text,
params.prompt_speech_path,
token2wav_model,
)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
else:
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
else: else:
generated_ids = model.decode( generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device) feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
@ -521,18 +604,14 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
answers = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
questions_with_history = [ # questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"] # cut.custom["question"] for cut in batch["supervisions"]["cut"]
] # ]
answer_cosyvoice_speech_token = [ # texts = [
cut.custom["answer_cosyvoice_speech_token"] # question.split("<USER>: ")[-1].strip()
for cut in batch["supervisions"]["cut"] # for question in questions_with_history
] # ]
texts = [
question.split("<USER>: ")[-1].strip()
for question in questions_with_history
]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
@ -636,9 +715,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
model, tokenizer = get_model(params, device) model, tokenizer = get_model(params, device)
token2wav_model = CosyVoice( if "CosyVoice2" in params.token2wav_path:
params.token2wav_path, load_jit=False, load_trt=False, fp16=False token2wav_model = CosyVoice2(
) params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
else:
token2wav_model = CosyVoice(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -656,8 +740,9 @@ def main():
return False return False
return True return True
test_sets_cuts = data_module.test_cuts() # TODO: FIX ME
# test_sets_cuts = data_module.test_cuts()
test_sets_cuts = data_module.test_cuts_en_vocalnet()
test_sets = test_sets_cuts.keys() test_sets = test_sets_cuts.keys()
test_dls = [ test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt)) data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))