support text2speech ultrachat

root 2025-06-02 23:16:03 -07:00
parent 49256fa917
commit 4c0396f8f2
4 changed files with 224 additions and 82 deletions

View File

@@ -3,7 +3,7 @@
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
set -eou pipefail
stage=$1
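# assumed usage (stop_stage is read further down in the script): pass a start
# and a stop stage, e.g. "bash <this-script> 19 19" to run only stage 19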
@@ -123,7 +123,9 @@ fi
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "stage 19: Training TTS Model"
exp_dir=./qwen_omni/exp_tts
exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
@@ -141,17 +143,16 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
fi
done
fi
train_cmd_args="--batch-size 64 \
# --dataset ultra_chat_voice_assistant
train_cmd_args="--batch-size 30 \
--exp-dir $exp_dir \
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--enable-speech-input False \
--deepspeed \
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/emilia_en \
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--num-epochs 2 \
--num-epochs 3 \
--use-lora False --unfreeze-llm False --enable-speech-output True"
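# ds_config_zero1.json is not included in this diff; a minimal DeepSpeed ZeRO
# stage-1 config (a sketch only, the committed file may differ) would be:
#   { "train_micro_batch_size_per_gpu": "auto", "zero_optimization": { "stage": 1 } }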
if [ "$latest_checkpoint_step" -ge 0 ]; then
@@ -168,66 +169,66 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
fi
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "stage 20: Training TTS Model"
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
if [ ! -L "/workspace/slam" ]; then
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
fi
exp_dir=./qwen_omni/exp_test
ngpu=4
# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
# log "stage 20: Training TTS Model"
# echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
# if [ ! -L "/workspace/slam" ]; then
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
# fi
# exp_dir=./qwen_omni/exp_test
# ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
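# e.g. with checkpoint_name=checkpoint-1000, the parameter expansion
# ${checkpoint_name#checkpoint-} strips the "checkpoint-" prefix and yields 1000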
# latest_checkpoint_step=-1
# # Check if exp_dir exists and is a directory
# if [ -d "$exp_dir" ]; then
# # List directories matching checkpoint-* and find the one with the largest step number
# for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
# checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# # Extract step number using parameter expansion
# current_step=${checkpoint_name#checkpoint-}
# # Ensure current_step is a number
# if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
# latest_checkpoint_step=$current_step
# fi
# done
# fi
train_cmd_args="--max-duration 150 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--dataset vocalnet_ultrachat_voiceassistant \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True --on-the-fly-feats True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
# train_cmd_args="--max-duration 150 \
# --enable-musan False \
# --exp-dir $exp_dir \
# --speech-encoder-path-or-name models/large-v2.pt \
# --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
# --dataset vocalnet_ultrachat_voiceassistant \
# --manifest-dir data/fbank \
# --deepspeed \
# --deepspeed_config ./qwen_omni/ds_config_zero1.json \
# --use-flash-attn True --on-the-fly-feats True \
# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
# if [ "$latest_checkpoint_step" -ge 0 ]; then
# log "Continuing training from checkpoint-$latest_checkpoint_step"
# step=$latest_checkpoint_step
# train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
# else
# log "Starting training from scratch as no checkpoint was found in $exp_dir"
# # No pretrained model or sampler state dict needed for the first run
# fi
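# the rendezvous flags below assume a SLURM launch that exports
# SLURM_JOB_NUM_NODES, SLURM_JOBID, MASTER_ADDR and MASTER_PORT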
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
# torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
# $train_cmd_args
# fi
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
log "stage 21: TTS Decoding Test Set"
exp_dir=./qwen_omni/exp_tts
torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output True \
--token2wav-path /workspace/CosyVoice2-0.5B \
--use-lora True
fi
# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
# log "stage 21: TTS Decoding Test Set"
# exp_dir=./qwen_omni/exp_tts
# torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
# --exp-dir $exp_dir \
# --speech-encoder-path-or-name models/large-v2.pt \
# --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
# --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
# --use-flash-attn True \
# --enable-speech-output True \
# --token2wav-path /workspace/CosyVoice2-0.5B \
# --use-lora True
# fi

View File

@@ -437,7 +437,8 @@ class SPEECH_LLM(nn.Module):
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
# text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
text_input_embeds_list = []
for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][
i,
@@ -445,14 +446,14 @@ class SPEECH_LLM(nn.Module):
+ input_seq_len[i]
- 1,
]
text_last_hidden_lists.append(text_last_hidden)
# text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[
i,
text_input_start_index_list[i]
+ 1 : text_input_start_index_list[i]
+ input_seq_len[i],
] # exclude bos
text_embeds_list.append(text_embed)
# text_embeds_list.append(text_embed)
text_input_embeds = torch.cat(
[
@@ -480,8 +481,9 @@ class SPEECH_LLM(nn.Module):
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
)
# breakpoint()
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
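# worked example of the truncation above: with audio_embeddings.shape[1] == 200
# and start_idx == 180 there are 20 codec positions left, so a 30-frame
# text_input_embeds is cut down to its first 20 frames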

View File

@@ -68,6 +68,7 @@ from transformers import (
)
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from pathlib import Path
from train import add_model_arguments, add_training_arguments, get_params, get_model
from utils import ( # filter_uneven_sized_batch,
@@ -171,11 +172,14 @@ def data_collator(batch):
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": item["text"]},
]
# message_list_item += [
# {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": item["text"]},
# ]
messages.append(message_list_item)
durations.append(item["duration"])
ids.append(item["id"])
ids.append(item["index"] if "index" in item else item["id"])
lang.append(item["language"])
dnsmos.append(item["dnsmos"])
return {
"speech_tokens": speech_tokens,
@@ -183,7 +187,92 @@ def data_collator(batch):
"durations": durations,
"ids": ids,
"lang": lang,
"dnsmos": dnsmos,
}
def data_collator_concate_items(batch, concat_items_num: int = 3):
"""Concatenate *concat_items_num* consecutive dataset items into one.
The function groups the incoming ``batch`` (a list of dataset items)
into non-overlapping chunks of *concat_items_num*. For each group it
concatenates the textual fields and speech codec tokens so that the
model generates one longer utterance instead of several short ones.
Any remainder (when ``len(batch)`` is not divisible by
*concat_items_num*) is also kept as a smaller group.
"""
grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
grouped_ids, grouped_lang = [], []
# Iterate over the batch in strides of *concat_items_num*
for start_idx in range(0, len(batch), concat_items_num):
group = batch[start_idx : start_idx + concat_items_num]
if not group:
continue
# 1) Speech tokens --------------------------------------------------
# ``item['code']`` can be a list[int] or a 1-D tensor. Use the first
# element to decide how to concatenate.
first_code = group[0]["code"]
if isinstance(first_code, torch.Tensor):
concat_code = torch.cat([item["code"] for item in group], dim=0)
else:
# assume list / iterable of ints
concat_code = []
for item in group:
concat_code.extend(item["code"])
# 2) Text -----------------------------------------------------------
concat_text = "".join([item["text"] for item in group])
# 3) Build chat template messages -----------------------------------
message_list_item = [
{
"role": "user",
"content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": concat_text},
]
# 4) Misc meta fields ----------------------------------------------
total_duration = sum(item["duration"] for item in group)
group_ids = [item.get("index", item.get("id")) for item in group]
language = group[0].get("language", "")
# 5) Append to output lists ----------------------------------------
grouped_speech_tokens.append(concat_code)
grouped_messages.append(message_list_item)
grouped_durations.append(total_duration)
grouped_ids.append(group_ids)
grouped_lang.append(language)
return {
"speech_tokens": grouped_speech_tokens,
"messages": grouped_messages,
"durations": grouped_durations,
"ids": grouped_ids,
"lang": grouped_lang,
}
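# A minimal wiring sketch (assumed, not part of this commit): collate_fn is
# called with the batch as its only argument, so concat_items_num has to be
# bound first, e.g. via functools.partial:
from functools import partial
concat_collate_fn = partial(data_collator_concate_items, concat_items_num=3)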
def data_collator_ultra_chat(batch):
speech_tokens, messages, durations, ids = [], [], [], []
for item in batch:
speech_tokens.append(item["custom"]["speech_token"])
text = item["supervisions"][0]["text"]
message_list_item = []
message_list_item += [
{"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": text},
]
messages.append(message_list_item)
durations.append(item["duration"])
ids.append(item["id"])
return {
"speech_tokens": speech_tokens,
"messages": messages,
"durations": durations,
"ids": ids,
}
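# data_collator_ultra_chat assumes Lhotse-style cut dicts as stored in the
# JSONL manifests, roughly of this shape (a sketch inferred from the fields
# accessed above; the values are illustrative):
#   {"id": "cut-0001", "duration": 3.2,
#    "supervisions": [{"text": "Hello there."}],
#    "custom": {"speech_token": [12, 873, 44]}}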
def compute_loss(
@@ -470,13 +559,21 @@ def run(rank, world_size, args):
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
# print(params.dataset)
ds = load_dataset(params.dataset, split="train")
# shuffle the dataset
ds = ds.shuffle(seed=42)
train_test_split = ds.train_test_split(test_size=1000, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
# train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
if params.dataset == "ultra_chat_voice_assistant":
data_dir = "data/fbank"
json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
ds = load_dataset("json", data_files=json_file_lists, split="train")
# shuffle the dataset
train_dataset = ds.shuffle(seed=42)
eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
else:
data_dir = Path(params.dataset)
json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
ds = load_dataset("json", data_files=json_file_lists, split="train")
# shuffle the dataset
ds = ds.shuffle(seed=42)
train_test_split = ds.train_test_split(test_size=1000, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_dl = StatefulDataLoader(
@@ -486,7 +583,7 @@ def run(rank, world_size, args):
shuffle=False,
num_workers=4,
prefetch_factor=2,
collate_fn=data_collator
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
)
if sampler_state_dict is not None:
    train_dl.load_state_dict(sampler_state_dict)
valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
@@ -497,7 +594,7 @@ def run(rank, world_size, args):
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=data_collator
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
)
if args.tensorboard and rank == 0:

View File

@@ -12,12 +12,12 @@ from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
from tqdm import tqdm
import kaldialign
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import numpy as np
Pathlike = Union[str, Path]
@@ -431,3 +431,45 @@ def write_error_stats(
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def get_dataset_statistics(dataset, save_filename):
speech_token_lengths = []
text_lengths = []
for item in tqdm(dataset):
if 'custom' not in item:
speech_token = item["code"]
text = item["text"]
else:
speech_token = item["custom"]["speech_token"]
text = item["supervisions"][0]["text"]
speech_token_lengths.append(len(speech_token))
text_lengths.append(len(text))
speech_token_length_array = np.array(speech_token_lengths)
text_length_array = np.array(text_lengths)
# Compute and store the length statistics
def get_length_stats(lengths_array):
length_stats = []
length_stats.append(["count", f"{len(lengths_array)}"]) # 总数
length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"]) # median 和 50% percentile 是一样的
length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
return length_stats
speech_length_stats = get_length_stats(speech_token_length_array)
text_length_stats = get_length_stats(text_length_array)
with open(save_filename, "w") as f:
print("speech_tokens 长度统计指标:", file=f)
for stat_name, stat_value in speech_length_stats:
print(f"{stat_name:<15}: {stat_value}", file=f)
print("\ntext 长度统计指标:", file=f)
for stat_name, stat_value in text_length_stats:
print(f"{stat_name:<15}: {stat_value}", file=f)
return speech_token_lengths, text_lengths
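# usage sketch (assumed call site; the output filename is illustrative):
#   speech_lens, text_lens = get_dataset_statistics(train_dataset, "dataset_stats.txt")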