support text2speech ultrachat

This commit is contained in:
root 2025-06-02 23:16:03 -07:00
parent 49256fa917
commit 4c0396f8f2
4 changed files with 224 additions and 82 deletions

View File

@@ -3,7 +3,7 @@
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
+# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"

 set -eou pipefail
 stage=$1
@@ -123,7 +123,9 @@ fi
 if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
   log "stage 19: Training TTS Model"
-  exp_dir=./qwen_omni/exp_tts
+  exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
+  exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
+  exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
   pretrained_dir=./qwen_omni/exp_speech2text
   ngpu=4
@@ -141,17 +143,16 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
       fi
     done
   fi
-  train_cmd_args="--batch-size 64 \
+  # --dataset ultra_chat_voice_assistant
+  train_cmd_args="--batch-size 30 \
     --exp-dir $exp_dir \
-    --last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --enable-speech-input False \
     --deepspeed \
-    --dataset /lustre/fsw/general_sa/yuekaiz/s2s/emilia_en \
+    --dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
     --use-flash-attn True \
-    --num-epochs 2 \
+    --num-epochs 3 \
     --use-lora False --unfreeze-llm False --enable-speech-output True"
   if [ "$latest_checkpoint_step" -ge 0 ]; then
@@ -168,66 +169,66 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
 fi

-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-  log "stage 20: Training TTS Model"
-  echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
-  if [ ! -L "/workspace/slam" ]; then
-    cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
-  fi
-  exp_dir=./qwen_omni/exp_test
-  ngpu=4
-  latest_checkpoint_step=-1
-  # Check if exp_dir exists and is a directory
-  if [ -d "$exp_dir" ]; then
-    # List directories matching checkpoint-* and find the one with the largest step number
-    for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
-      checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
-      # Extract step number using parameter expansion
-      current_step=${checkpoint_name#checkpoint-}
-      # Ensure current_step is a number
-      if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
-        latest_checkpoint_step=$current_step
-      fi
-    done
-  fi
-  train_cmd_args="--max-duration 150 \
-    --enable-musan False \
-    --exp-dir $exp_dir \
-    --speech-encoder-path-or-name models/large-v2.pt \
-    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
-    --dataset vocalnet_ultrachat_voiceassistant \
-    --manifest-dir data/fbank \
-    --deepspeed \
-    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
-    --use-flash-attn True --on-the-fly-feats True \
-    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
-  if [ "$latest_checkpoint_step" -ge 0 ]; then
-    log "Continuing training from checkpoint-$latest_checkpoint_step"
-    step=$latest_checkpoint_step
-    train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
-  else
-    log "Starting training from scratch as no checkpoint was found in $exp_dir"
-    # No pretrained model or sampler state dict needed for the first run
-  fi
-  torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
-    $train_cmd_args
-fi
-if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
-  log "stage 21: TTS Decoding Test Set"
-  exp_dir=./qwen_omni/exp_tts
-  torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
-    --exp-dir $exp_dir \
-    --speech-encoder-path-or-name models/large-v2.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
-    --use-flash-attn True \
-    --enable-speech-output True \
-    --token2wav-path /workspace/CosyVoice2-0.5B \
-    --use-lora True
-fi
+# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+#   log "stage 20: Training TTS Model"
+#   echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
+#   if [ ! -L "/workspace/slam" ]; then
+#     cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+#   fi
+#   exp_dir=./qwen_omni/exp_test
+#   ngpu=4
+#   latest_checkpoint_step=-1
+#   # Check if exp_dir exists and is a directory
+#   if [ -d "$exp_dir" ]; then
+#     # List directories matching checkpoint-* and find the one with the largest step number
+#     for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+#       checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+#       # Extract step number using parameter expansion
+#       current_step=${checkpoint_name#checkpoint-}
+#       # Ensure current_step is a number
+#       if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+#         latest_checkpoint_step=$current_step
+#       fi
+#     done
+#   fi
+#   train_cmd_args="--max-duration 150 \
+#     --enable-musan False \
+#     --exp-dir $exp_dir \
+#     --speech-encoder-path-or-name models/large-v2.pt \
+#     --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+#     --dataset vocalnet_ultrachat_voiceassistant \
+#     --manifest-dir data/fbank \
+#     --deepspeed \
+#     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+#     --use-flash-attn True --on-the-fly-feats True \
+#     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
+#   if [ "$latest_checkpoint_step" -ge 0 ]; then
+#     log "Continuing training from checkpoint-$latest_checkpoint_step"
+#     step=$latest_checkpoint_step
+#     train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+#   else
+#     log "Starting training from scratch as no checkpoint was found in $exp_dir"
+#     # No pretrained model or sampler state dict needed for the first run
+#   fi
+#   torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+#     $train_cmd_args
+# fi
+# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
+#   log "stage 21: TTS Decoding Test Set"
+#   exp_dir=./qwen_omni/exp_tts
+#   torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
+#     --exp-dir $exp_dir \
+#     --speech-encoder-path-or-name models/large-v2.pt \
+#     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+#     --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
+#     --use-flash-attn True \
+#     --enable-speech-output True \
+#     --token2wav-path /workspace/CosyVoice2-0.5B \
+#     --use-lora True
+# fi
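
Note that of the three consecutive exp_dir assignments in stage 19, only the last one (exp_tts_emilia_en_tts_three_concat) takes effect; each shell assignment simply overwrites the previous. The auto-resume logic in this script scans $exp_dir for directories named checkpoint-<step> and resumes from the largest step. A minimal Python sketch of the same scan (a hypothetical helper, not part of this commit):

    import re
    from pathlib import Path

    def latest_checkpoint_step(exp_dir: str) -> int:
        """Return the largest <step> among exp_dir/checkpoint-<step>/ dirs, or -1 if none."""
        latest = -1
        for ckpt in Path(exp_dir).glob("checkpoint-*"):
            m = re.fullmatch(r"checkpoint-(\d+)", ckpt.name)
            if ckpt.is_dir() and m:
                latest = max(latest, int(m.group(1)))
        return latest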

View File

@@ -437,7 +437,8 @@ class SPEECH_LLM(nn.Module):
         audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

-        text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        # text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        text_input_embeds_list = []
         for i in range(len(text_label_start_index_list)):
             text_last_hidden = model_outputs.hidden_states[-1][
                 i,
@@ -445,14 +446,14 @@ class SPEECH_LLM(nn.Module):
                 + input_seq_len[i]
                 - 1,
             ]
-            text_last_hidden_lists.append(text_last_hidden)
+            # text_last_hidden_lists.append(text_last_hidden)
             text_embed = inputs_embeds[
                 i,
                 text_input_start_index_list[i]
                 + 1 : text_input_start_index_list[i]
                 + input_seq_len[i],
             ]  # exclude bos
-            text_embeds_list.append(text_embed)
+            # text_embeds_list.append(text_embed)

             text_input_embeds = torch.cat(
                 [
@@ -480,8 +481,9 @@ class SPEECH_LLM(nn.Module):
             ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
             if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
                 logging.warning(
-                    f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
+                    f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
                 )
+                # breakpoint()
                 text_input_embeds = text_input_embeds[
                     : audio_embeddings.shape[1] - start_idx
                 ]
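
This guard exists because the codec LM sees a fixed number of audio frames per sample: once the text-derived embeddings are placed starting at start_idx, anything longer than the remaining audio_embeddings.shape[1] - start_idx positions must be clipped, and the expanded warning now also logs the lengths involved. A toy sketch of the clip, with shapes invented for illustration:

    import torch

    audio_embeddings = torch.zeros(1, 10, 4)  # (batch, audio_frames, dim)
    text_input_embeds = torch.zeros(8, 4)     # (text_positions, dim)
    start_idx = 5                             # first audio frame the text may occupy

    room = audio_embeddings.shape[1] - start_idx  # 5 positions remain
    if text_input_embeds.shape[0] > room:
        text_input_embeds = text_input_embeds[:room]  # keep only what fits
    assert text_input_embeds.shape[0] == room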

View File

@@ -68,6 +68,7 @@ from transformers import (
 )
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader
+from pathlib import Path

 from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import (  # filter_uneven_sized_batch,
@@ -171,11 +172,14 @@ def data_collator(batch):
             {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
             {"role": "assistant", "content": item["text"]},
         ]
+        # message_list_item += [
+        #     {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
+        #     {"role": "assistant", "content": item["text"]},
+        # ]
         messages.append(message_list_item)
         durations.append(item["duration"])
-        ids.append(item["id"])
+        ids.append(item["index"] if "index" in item else item["id"])
         lang.append(item["language"])
-        dnsmos.append(item["dnsmos"])

     return {
         "speech_tokens": speech_tokens,
@@ -183,7 +187,92 @@ def data_collator(batch):
         "durations": durations,
         "ids": ids,
         "lang": lang,
-        "dnsmos": dnsmos,
     }
+
+
+def data_collator_concate_items(batch, concat_items_num: int = 3):
+    """Concatenate *concat_items_num* consecutive dataset items into one.
+
+    The function groups the incoming ``batch`` (a list of dataset items)
+    into non-overlapping chunks of *concat_items_num*. For each group it
+    concatenates the textual fields and speech codec tokens so that the
+    model generates one longer utterance instead of several short ones.
+    Any remainder (when ``len(batch)`` is not divisible by
+    *concat_items_num*) is also kept as a smaller group.
+    """
+    grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
+    grouped_ids, grouped_lang = [], []
+
+    # Iterate over the batch in strides of *concat_items_num*
+    for start_idx in range(0, len(batch), concat_items_num):
+        group = batch[start_idx : start_idx + concat_items_num]
+        if not group:
+            continue
+
+        # 1) Speech tokens --------------------------------------------------
+        # ``item['code']`` can be a list[int] or a 1-D tensor. Use the first
+        # element to decide how to concatenate.
+        first_code = group[0]["code"]
+        if isinstance(first_code, torch.Tensor):
+            concat_code = torch.cat([item["code"] for item in group], dim=0)
+        else:
+            # assume list / iterable of ints
+            concat_code = []
+            for item in group:
+                concat_code.extend(item["code"])
+
+        # 2) Text -----------------------------------------------------------
+        concat_text = "".join([item["text"] for item in group])
+
+        # 3) Build chat template messages -----------------------------------
+        message_list_item = [
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": concat_text},
+        ]
+
+        # 4) Misc meta fields ----------------------------------------------
+        total_duration = sum(item["duration"] for item in group)
+        group_ids = [item.get("index", item.get("id")) for item in group]
+        language = group[0].get("language", "")
+
+        # 5) Append to output lists ----------------------------------------
+        grouped_speech_tokens.append(concat_code)
+        grouped_messages.append(message_list_item)
+        grouped_durations.append(total_duration)
+        grouped_ids.append(group_ids)
+        grouped_lang.append(language)
+
+    return {
+        "speech_tokens": grouped_speech_tokens,
+        "messages": grouped_messages,
+        "durations": grouped_durations,
+        "ids": grouped_ids,
+        "lang": grouped_lang,
+    }
+
+
+def data_collator_ultra_chat(batch):
+    speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
+    for i, item in enumerate(batch):
+        speech_tokens.append(item["custom"]["speech_token"])
+        text = item["supervisions"][0]["text"]
+        message_list_item = []
+        message_list_item += [
+            {"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": text},
+        ]
+        messages.append(message_list_item)
+        durations.append(item["duration"])
+        ids.append(item["id"])
+    return {
+        "speech_tokens": speech_tokens,
+        "messages": messages,
+        "durations": durations,
+        "ids": ids,
+    }

 def compute_loss(
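
The new data_collator_ultra_chat reads lhotse-style cut dicts: codec tokens from item["custom"]["speech_token"] and the transcript from item["supervisions"][0]["text"]. Note that lang and dnsmos are initialized here but never populated or returned. A toy invocation, with the field values invented, assuming the module's data_collator_ultra_chat and DEFAULT_SPEECH_TOKEN are in scope:

    toy_item = {
        "id": "ultrachat_0001",
        "duration": 3.2,
        "custom": {"speech_token": [12, 87, 87, 3]},            # codec token ids
        "supervisions": [{"text": "Hello there, how can I help?"}],
    }
    batch = data_collator_ultra_chat([toy_item])
    print(batch["ids"], batch["durations"])  # ['ultrachat_0001'] [3.2]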
@@ -470,13 +559,21 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-    # print(params.dataset)
-    ds = load_dataset(params.dataset, split="train")
-    # shuffle the dataset
-    ds = ds.shuffle(seed=42)
-    train_test_split = ds.train_test_split(test_size=1000, seed=42)
-    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
-    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
+    if params.dataset == "ultra_chat_voice_assistant":
+        data_dir = "data/fbank"
+        json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
+        ds = load_dataset("json", data_files=json_file_lists, split="train")
+        # shuffle the dataset
+        train_dataset = ds.shuffle(seed=42)
+        eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
+    else:
+        data_dir = Path(params.dataset)
+        json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
+        ds = load_dataset("json", data_files=json_file_lists, split="train")
+        # shuffle the dataset
+        ds = ds.shuffle(seed=42)
+        train_test_split = ds.train_test_split(test_size=1000, seed=42)
+        train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]

     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
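
In the else branch, only files matching *.jsonl are globbed, so a compressed manifest such as cuts_ultrachat_train.jsonl.gz would not be picked up; that is presumably why the ultra_chat_voice_assistant branch lists its files explicitly. A small runnable sketch of the globbing behaviour (paths invented):

    from pathlib import Path

    data_dir = Path("/tmp/manifests_demo")        # stand-in for params.dataset
    data_dir.mkdir(exist_ok=True)
    (data_dir / "a.jsonl").touch()
    (data_dir / "b.jsonl.gz").touch()

    found = sorted(str(p) for p in data_dir.glob("*.jsonl"))
    print(found)  # only the .jsonl file; the .jsonl.gz manifest is skipped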
@@ -486,7 +583,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=4,
         prefetch_factor=2,
-        collate_fn=data_collator
+        collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
     )
     train_dl.load_state_dict(sampler_state_dict)
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
@@ -497,7 +594,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=data_collator
+        collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
     )

     if args.tensorboard and rank == 0:
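
Unlike the two collators wired above, data_collator_concate_items takes an extra concat_items_num argument, so it cannot be passed as collate_fn directly. A sketch of how it might be bound (hypothetical wiring, not part of this commit; train_dataset, sampler, and params are the names from the surrounding code):

    from functools import partial

    collate = partial(data_collator_concate_items, concat_items_num=3)
    train_dl = StatefulDataLoader(
        train_dataset,
        batch_size=params.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=4,
        prefetch_factor=2,
        collate_fn=collate,
    )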

View File

@@ -12,12 +12,12 @@ from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
+from tqdm import tqdm

 import kaldialign
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
+import numpy as np

 Pathlike = Union[str, Path]
@@ -431,3 +431,45 @@ def write_error_stats(
             print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

     return float(tot_err_rate)
+
+
+def get_dataset_statistics(dataset, save_filename):
+    speech_token_lengths = []
+    text_lengths = []
+    for item in tqdm(dataset):
+        if 'custom' not in item:
+            speech_token = item["code"]
+            text = item["text"]
+        else:
+            speech_token = item["custom"]["speech_token"]
+            text = item["supervisions"][0]["text"]
+        speech_token_lengths.append(len(speech_token))
+        text_lengths.append(len(text))
+
+    speech_token_length_array = np.array(speech_token_lengths)
+    text_length_array = np.array(text_lengths)
+
+    # Compute and store the length statistics
+    def get_length_stats(lengths_array):
+        length_stats = []
+        length_stats.append(["count", f"{len(lengths_array)}"])  # total number of items
+        length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
+        length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
+        length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
+        length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
+        length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"])  # the median is the 50th percentile
+        length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
+        length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
+        length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
+        length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
+        length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
+        return length_stats
+
+    speech_length_stats = get_length_stats(speech_token_length_array)
+    text_length_stats = get_length_stats(text_length_array)
+    with open(save_filename, "w") as f:
+        print("speech_tokens length statistics:", file=f)
+        for stat_name, stat_value in speech_length_stats:
+            print(f"{stat_name:<15}: {stat_value}", file=f)
+        print("\ntext length statistics:", file=f)
+        for stat_name, stat_value in text_length_stats:
+            print(f"{stat_name:<15}: {stat_value}", file=f)
+    return speech_token_lengths, text_lengths
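
A sketch of how get_dataset_statistics might be called on one of the jsonl manifests used above (the output filename is an example):

    from datasets import load_dataset

    ds = load_dataset(
        "json",
        data_files=["data/fbank/cuts_ultrachat_train.jsonl.gz"],
        split="train",
    )
    get_dataset_statistics(ds, "ultrachat_length_stats.txt")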