add tts task decode

root 2025-05-27 02:12:22 -07:00
parent 1281d7a515
commit 5a7c72cb47
5 changed files with 558 additions and 31 deletions

View File

@@ -0,0 +1,233 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
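# Expects two positional arguments: the first and last stage to run.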
stage=$1
stop_stage=$2
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
if [ ! -L "/workspace/slam" ]; then
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
fi
log "stage 17: Training Speech2Speech Model, full parameters"
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 200 \
--enable-musan False \
--exp-dir $exp_dir \
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True --on-the-fly-speed-perturb False \
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset vocalnet_ultrachat_voiceassistant_instruct_s2s --num-epochs 10 \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
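# MASTER_ADDR/MASTER_PORT and the SLURM_* variables are assumed to be provided by the SLURM launch environment.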
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
# Create the symlink if it does not already exist
if [ ! -L "/workspace/slam" ]; then
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
fi
log "stage 17: Training Speech2Speech Model, full parameters"
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 200 \
--enable-musan False \
--exp-dir $exp_dir \
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True --on-the-fly-speed-perturb False \
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech --num-epochs 10 \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
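# Cache Hugging Face downloads on the shared Lustre filesystem for the stages below.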
export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "stage 19: Training TTS Model"
exp_dir=./qwen_omni/exp_tts
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
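# The LLM stays frozen here (--use-lora False --unfreeze-llm False); training targets the speech-output path.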
train_cmd_args="--batch-size 64 \
--exp-dir $exp_dir \
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--enable-speech-input False \
--deepspeed \
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/emilia_en \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--num-epochs 2 \
--use-lora False --unfreeze-llm False --enable-speech-output True"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train_tts.py \
$train_cmd_args
fi
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "stage 20: Training TTS Model"
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
if [ ! -L "/workspace/slam" ]; then
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
fi
exp_dir=./qwen_omni/exp_test
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 150 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--dataset vocalnet_ultrachat_voiceassistant \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True --on-the-fly-feats True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
log "stage 21: TTS Decoding Test Set"
exp_dir=./qwen_omni/exp_tts
torchrun --nproc_per_node=4 ./qwen_omni/decode_tts.py \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output True \
--token2wav-path /lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice2-0.5B \
--use-lora True
fi

View File

@@ -242,9 +242,13 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
# exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
# Only the last assignment of target_datasets below takes effect; the full set is:
# (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh")
declare -a target_datasets=("mmsu")
NUM_CLIENT_JOBS=4 # Number of parallel client jobs
BASE_PORT=8000 # Base port for servers
@@ -367,6 +371,8 @@ if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
N_GPUS=4 # Define the number of GPUs/processes you want to launch
@@ -376,10 +382,10 @@ if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
--checkpoint-path $exp_dir/checkpoint-55276/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output False \
--port $(expr 8000 + $id) \
--port $(expr 18000 + $id) \
--use-lora True &
done

View File

@@ -77,7 +77,7 @@ sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_path, codec_decoder
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
@@ -95,7 +95,6 @@ def audio_decode_cosyvoice2(
Returns:
torch.Tensor: Generated audio waveform.
"""
prompt_speech_16k = load_wav(prompt_speech_path, 16000)
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
@@ -555,10 +554,11 @@ def decode_one_batch(
# audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
if "CosyVoice2" in params.token2wav_path:
prompt_speech_16k = load_wav(params.prompt_speech_path, 16000)
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
params.prompt_text,
params.prompt_speech_path,
prompt_speech_16k,
token2wav_model,
)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)

View File

@@ -0,0 +1,294 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
import copy
import logging
import os
import random
import sys
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import soundfile as sf
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from datasets import load_dataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from train import add_model_arguments, add_training_arguments, get_params, get_model
from decode import audio_decode_cosyvoice2
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_local_rank,
get_rank,
get_world_size,
setup_logger,
str2bool,
)
sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
from cosyvoice.cli.cosyvoice import CosyVoice2
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="The batch size to use.",
)
parser.add_argument(
"--split-name",
type=str,
default="test_en",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument(
"--token2wav-path",
type=str,
default="/workspace/CosyVoice-300M-SFT",
help="The path to the token2wav model",
)
add_model_arguments(parser)
return parser
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip the template tokens after the placeholder ('+ 2' in the original preprocess skipped only 'assistant', '\n')
# WAR: TODO FIXME check qwen3
# THIS IS THE ONLY DIFFERENCE FROM preprocess
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
target_ids[row, col] = default_speech_token_id
# remove default_speech_token_id from target_ids and input_ids
batch_size = target_ids.size(0)
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
def data_collator(batch):
prompt_texts, prompt_speech_16k, messages, ids = [], [], [], []
for i, item in enumerate(batch):
# speech_tokens.append(item["prompt_audio_cosy2_tokens"])
message_list_item = []
message_list_item += [
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]
messages.append(message_list_item)
ids.append(item["id"])
prompt_texts.append(item["prompt_text"])
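# prompt_audio is assumed to already be a 16 kHz waveform usable as the CosyVoice2 zero-shot prompt.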
prompt_speech_16k.append(item["prompt_audio"])
return {
"prompt_texts": prompt_texts,
"prompt_speech_16k": prompt_speech_16k,
"messages": messages,
"ids": ids,
}
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params.log_dir = Path(params.exp_dir) / "log-results-wav"
params.log_dir.mkdir(parents=True, exist_ok=True)
fix_random_seed(params.seed)
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-decode-tts")
logging.info(params)
logging.info("About to create model")
model, tokenizer = get_model(params)
if torch.cuda.is_available():
device = torch.device("cuda", get_local_rank())
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
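# Shard the evaluation split across ranks so each GPU decodes a disjoint subset.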
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
data_loader = DataLoader(
dataset,
batch_size=params.batch_size,
sampler=sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=data_collator
)
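# Load the CosyVoice2 token-to-wav decoder once per rank.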
token2wav_model = CosyVoice2(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
for batch in data_loader:
messages = batch["messages"]
prompt_texts = batch["prompt_texts"]
prompt_speech_16k = batch["prompt_speech_16k"]
ids = batch["ids"]
input_ids, attention_mask, _ = preprocess(messages, tokenizer)
generated_ids, generated_speech_output = model.decode_with_speech_output(
None, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
generated_speech_output = [
generated_speech_output
] # WAR: only support batch = 1 for now
for cut_id, audio_tokens, prompt_text, prompt_speech in zip(ids, generated_speech_output, prompt_texts, prompt_speech_16k):
speech_file_name = params.log_dir / f"{cut_id}.wav"
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
if "CosyVoice2" in params.token2wav_path:
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
prompt_text,
prompt_speech,
token2wav_model,
)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
logging.info("Done!")
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()

View File

@@ -479,12 +479,12 @@ class SPEECH_LLM(nn.Module):
start_idx == start_idx_re_compute
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
)
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
audio_embeddings[
i, start_idx : start_idx + text_input_embeds.shape[0]
] += text_input_embeds
@@ -592,27 +592,18 @@ class SPEECH_LLM(nn.Module):
- generated_speech_tokens: List of lists, where each inner list contains
the generated speech codec tokens for a batch item.
"""
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
if (
not self.codec_lm
or not self.speech_token_projector
or not self.codec_lm_head
):
raise ValueError(
"codec_lm and associated layers must be initialized to generate speech output."
)
batch_size = input_ids.shape[0]
assert batch_size == 1, "Batch size must be 1 for speech generation."
device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
# Merge speech features with prompt embeddings
if fbank is not None:
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
(
merged_prompt_inputs_embeds,
merged_prompt_attention_mask,
@@ -621,6 +612,9 @@ class SPEECH_LLM(nn.Module):
) = self._merge_input_ids_with_speech_features(
speech_features, prompt_embeds, input_ids, attention_mask
)
else:
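# Text-only (TTS) prompt: there are no speech features to merge, so use the token embeddings as-is.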
merged_prompt_inputs_embeds = prompt_embeds
merged_prompt_attention_mask = attention_mask
# --- 2. Generate Text using LLM ---
# Use merged embeds/mask as input to generate