diff --git a/egs/speech_llm/SPEECH2SPEECH/exp.sh b/egs/speech_llm/SPEECH2SPEECH/exp.sh
index 03461b97b..26b2c8745 100644
--- a/egs/speech_llm/SPEECH2SPEECH/exp.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/exp.sh
@@ -3,7 +3,7 @@
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
+# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
 set -eou pipefail

 stage=$1
@@ -123,7 +123,9 @@ fi

 if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
   log "stage 19: Training TTS Model"
-  exp_dir=./qwen_omni/exp_tts
+  exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
+  exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
+  exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
   pretrained_dir=./qwen_omni/exp_speech2text

   ngpu=4
@@ -141,17 +143,16 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
       fi
     done
   fi
-
-  train_cmd_args="--batch-size 64 \
+  # --dataset ultra_chat_voice_assistant
+  train_cmd_args="--batch-size 30 \
     --exp-dir $exp_dir \
-    --last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --enable-speech-input False \
     --deepspeed \
-    --dataset /lustre/fsw/general_sa/yuekaiz/s2s/emilia_en \
+    --dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
     --use-flash-attn True \
-    --num-epochs 2 \
+    --num-epochs 3 \
     --use-lora False --unfreeze-llm False --enable-speech-output True"

   if [ "$latest_checkpoint_step" -ge 0 ]; then
@@ -168,66 +169,66 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then

 fi

-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-  log "stage 20: Training TTS Model"
-  echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
-  if [ ! -L "/workspace/slam" ]; then
-    cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
-  fi
-  exp_dir=./qwen_omni/exp_test
-  ngpu=4
+# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+#   log "stage 20: Training TTS Model"
+#   echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
+#   if [ ! -L "/workspace/slam" ]; then
+#     cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+#   fi
+#   exp_dir=./qwen_omni/exp_test
+#   ngpu=4

-  latest_checkpoint_step=-1
-  # Check if exp_dir exists and is a directory
-  if [ -d "$exp_dir" ]; then
-    # List directories matching checkpoint-* and find the one with the largest step number
-    for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
-      checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
-      # Extract step number using parameter expansion
-      current_step=${checkpoint_name#checkpoint-}
-      # Ensure current_step is a number
-      if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
-        latest_checkpoint_step=$current_step
-      fi
-    done
-  fi
+#   latest_checkpoint_step=-1
+#   # Check if exp_dir exists and is a directory
+#   if [ -d "$exp_dir" ]; then
+#     # List directories matching checkpoint-* and find the one with the largest step number
+#     for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+#       checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+#       # Extract step number using parameter expansion
+#       current_step=${checkpoint_name#checkpoint-}
+#       # Ensure current_step is a number
+#       if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+#         latest_checkpoint_step=$current_step
+#       fi
+#     done
+#   fi

-  train_cmd_args="--max-duration 150 \
-    --enable-musan False \
-    --exp-dir $exp_dir \
-    --speech-encoder-path-or-name models/large-v2.pt \
-    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
-    --dataset vocalnet_ultrachat_voiceassistant \
-    --manifest-dir data/fbank \
-    --deepspeed \
-    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
-    --use-flash-attn True --on-the-fly-feats True \
-    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
+#   train_cmd_args="--max-duration 150 \
+#     --enable-musan False \
+#     --exp-dir $exp_dir \
+#     --speech-encoder-path-or-name models/large-v2.pt \
+#     --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+#     --dataset vocalnet_ultrachat_voiceassistant \
+#     --manifest-dir data/fbank \
+#     --deepspeed \
+#     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+#     --use-flash-attn True --on-the-fly-feats True \
+#     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"

-  if [ "$latest_checkpoint_step" -ge 0 ]; then
-    log "Continuing training from checkpoint-$latest_checkpoint_step"
-    step=$latest_checkpoint_step
-    train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
-  else
-    log "Starting training from scratch as no checkpoint was found in $exp_dir"
-    # No pretrained model or sampler state dict needed for the first run
-  fi
+#   if [ "$latest_checkpoint_step" -ge 0 ]; then
+#     log "Continuing training from checkpoint-$latest_checkpoint_step"
+#     step=$latest_checkpoint_step
+#     train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+#   else
+#     log "Starting training from scratch as no checkpoint was found in $exp_dir"
+#     # No pretrained model or sampler state dict needed for the first run
+#   fi

-  torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
-    $train_cmd_args
-fi
+#   torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+#     $train_cmd_args
+# fi

-if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
-  log "stage 21: TTS Decoding Test Set"
-  exp_dir=./qwen_omni/exp_tts
-  torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
-    --exp-dir $exp_dir \
-    --speech-encoder-path-or-name models/large-v2.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
-    --use-flash-attn True \
-    --enable-speech-output True \
-    --token2wav-path /workspace/CosyVoice2-0.5B \
-    --use-lora True
-fi
+# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
+#   log "stage 21: TTS Decoding Test Set"
+#   exp_dir=./qwen_omni/exp_tts
+#   torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
+#     --exp-dir $exp_dir \
+#     --speech-encoder-path-or-name models/large-v2.pt \
+#     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+#     --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
+#     --use-flash-attn True \
+#     --enable-speech-output True \
+#     --token2wav-path /workspace/CosyVoice2-0.5B \
+#     --use-lora True
+# fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
index 3def803b5..baec602bb 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
@@ -437,7 +437,8 @@ class SPEECH_LLM(nn.Module):

         audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
-        text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        # text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        text_input_embeds_list = []
         for i in range(len(text_label_start_index_list)):
             text_last_hidden = model_outputs.hidden_states[-1][
                 i,
@@ -445,14 +446,14 @@ class SPEECH_LLM(nn.Module):
                 + input_seq_len[i]
                 - 1,
             ]
-            text_last_hidden_lists.append(text_last_hidden)
+            # text_last_hidden_lists.append(text_last_hidden)
             text_embed = inputs_embeds[
                 i,
                 text_input_start_index_list[i]
                 + 1 : text_input_start_index_list[i]
                 + input_seq_len[i],
             ]  # exclude bos
-            text_embeds_list.append(text_embed)
+            # text_embeds_list.append(text_embed)

             text_input_embeds = torch.cat(
                 [
@@ -480,8 +481,9 @@ class SPEECH_LLM(nn.Module):
                 ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
                 if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
                     logging.warning(
-                        f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
+                        f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
                     )
+                    # breakpoint()
                     text_input_embeds = text_input_embeds[
                         : audio_embeddings.shape[1] - start_idx
                     ]
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
index 38132e71e..6d35ed6a9 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
@@ -68,6 +68,7 @@ from transformers import (
 )
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader
+from pathlib import Path
 from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import (  # filter_uneven_sized_batch,
@@ -171,11 +172,14 @@ def data_collator(batch):
             {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
             {"role": "assistant", "content": item["text"]},
         ]
+        # message_list_item += [
+        #     {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
+        #     {"role": "assistant", "content": item["text"]},
+        # ]
         messages.append(message_list_item)
         durations.append(item["duration"])
-        ids.append(item["id"])
+        ids.append(item["index"] if "index" in item else item["id"])
         lang.append(item["language"])
-        dnsmos.append(item["dnsmos"])

     return {
         "speech_tokens": speech_tokens,
@@ -183,7 +187,92 @@ def data_collator(batch):
         "durations": durations,
         "ids": ids,
         "lang": lang,
-        "dnsmos": dnsmos,
     }
+
+def data_collator_concate_items(batch, concat_items_num: int = 3):
+    """Concatenate *concat_items_num* consecutive dataset items into one.
+
+    The function groups the incoming ``batch`` (a list of dataset items)
+    into non-overlapping chunks of *concat_items_num*. For each group it
+    concatenates the textual fields and speech codec tokens so that the
+    model generates one longer utterance instead of several short ones.
+
+    Any remainder (when ``len(batch)`` is not divisible by
+    *concat_items_num*) is also kept as a smaller group.
+    """
+
+    grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
+    grouped_ids, grouped_lang = [], []
+
+    # Iterate over the batch in strides of *concat_items_num*
+    for start_idx in range(0, len(batch), concat_items_num):
+        group = batch[start_idx : start_idx + concat_items_num]
+        if not group:
+            continue
+
+        # 1) Speech tokens --------------------------------------------------
+        # ``item['code']`` can be a list[int] or a 1-D tensor. Use the first
+        # element to decide how to concatenate.
+        first_code = group[0]["code"]
+        if isinstance(first_code, torch.Tensor):
+            concat_code = torch.cat([item["code"] for item in group], dim=0)
+        else:
+            # assume list / iterable of ints
+            concat_code = []
+            for item in group:
+                concat_code.extend(item["code"])
+
+        # 2) Text -----------------------------------------------------------
+        concat_text = "".join([item["text"] for item in group])
+
+        # 3) Build chat template messages -----------------------------------
+        message_list_item = [
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": concat_text},
+        ]
+
+        # 4) Misc meta fields ----------------------------------------------
+        total_duration = sum(item["duration"] for item in group)
+        group_ids = [item.get("index", item.get("id")) for item in group]
+        language = group[0].get("language", "")
+
+        # 5) Append to output lists ----------------------------------------
+        grouped_speech_tokens.append(concat_code)
+        grouped_messages.append(message_list_item)
+        grouped_durations.append(total_duration)
+        grouped_ids.append(group_ids)
+        grouped_lang.append(language)
+
+    return {
+        "speech_tokens": grouped_speech_tokens,
+        "messages": grouped_messages,
+        "durations": grouped_durations,
+        "ids": grouped_ids,
+        "lang": grouped_lang,
+    }
+
+def data_collator_ultra_chat(batch):
+    speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
+    for i, item in enumerate(batch):
+        speech_tokens.append(item["custom"]["speech_token"])
+        text = item["supervisions"][0]["text"]
+        message_list_item = []
+        message_list_item += [
+            {"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": text},
+        ]
+        messages.append(message_list_item)
+        durations.append(item["duration"])
+        ids.append(item["id"])
+
+    return {
+        "speech_tokens": speech_tokens,
+        "messages": messages,
+        "durations": durations,
+        "ids": ids,
+    }

 def compute_loss(
@@ -470,13 +559,21 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-    # print(params.dataset)
-    ds = load_dataset(params.dataset, split="train")
-    # shuffle the dataset
-    ds = ds.shuffle(seed=42)
-    train_test_split = ds.train_test_split(test_size=1000, seed=42)
-    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
-    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
+    if params.dataset == "ultra_chat_voice_assistant":
+        data_dir = "data/fbank"
+        json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
+        ds = load_dataset("json", data_files=json_file_lists, split="train")
+        # shuffle the dataset
+        train_dataset = ds.shuffle(seed=42)
+        eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
+    else:
+        data_dir = Path(params.dataset)
+        json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
+        ds = load_dataset("json", data_files=json_file_lists, split="train")
+        # shuffle the dataset
+        ds = ds.shuffle(seed=42)
+        train_test_split = ds.train_test_split(test_size=1000, seed=42)
+        train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]

     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
@@ -486,7 +583,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=4,
         prefetch_factor=2,
-        collate_fn=data_collator
+        collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
     )
     train_dl.load_state_dict(sampler_state_dict)
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
@@ -497,7 +594,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=data_collator
+        collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
     )

     if args.tensorboard and rank == 0:
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
index 0ebaa6eb4..81f7c0d5c 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
@@ -12,12 +12,12 @@ from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
-
+from tqdm import tqdm
 import kaldialign
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
-
+import numpy as np


 Pathlike = Union[str, Path]
@@ -431,3 +431,45 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

     return float(tot_err_rate)
+
+
+def get_dataset_statistics(dataset, save_filename):
+    speech_token_lengths = []
+    text_lengths = []
+    for item in tqdm(dataset):
+        if 'custom' not in item:
+            speech_token = item["code"]
+            text = item["text"]
+        else:
+            speech_token = item["custom"]["speech_token"]
+            text = item["supervisions"][0]["text"]
+        speech_token_lengths.append(len(speech_token))
+        text_lengths.append(len(text))
+    speech_token_length_array = np.array(speech_token_lengths)
+    text_length_array = np.array(text_lengths)
+    # Compute and store the length statistics
+    def get_length_stats(lengths_array):
+        length_stats = []
+        length_stats.append(["count", f"{len(lengths_array)}"])  # total number of items
+        length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
+        length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
+        length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
+        length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
+        length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"])  # the median equals the 50th percentile
+        length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
+        length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
+        length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
+        length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
+        length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
+        return length_stats
+    speech_length_stats = get_length_stats(speech_token_length_array)
+    text_length_stats = get_length_stats(text_length_array)
+    with open(save_filename, "w") as f:
+        print("speech_tokens length statistics:", file=f)
+        for stat_name, stat_value in speech_length_stats:
+            print(f"{stat_name:<15}: {stat_value}", file=f)
+        print("\ntext length statistics:", file=f)
+        for stat_name, stat_value in text_length_stats:
+            print(f"{stat_name:<15}: {stat_value}", file=f)
+
+    return speech_token_lengths, text_lengths
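
Usage notes for the patch above.

A minimal sketch of how the new data_collator_concate_items collator groups a batch, assuming qwen_omni/train_tts.py and its imports resolve; the toy items and values are hypothetical and only mirror the code/text/duration/index/language fields the collator reads:

# Toy check of data_collator_concate_items (hypothetical values).
from train_tts import data_collator_concate_items

toy_batch = [
    {"code": [1, 2, 3], "text": "hello ", "duration": 2, "index": "utt-0", "language": "en"},
    {"code": [4, 5], "text": "world ", "duration": 1, "index": "utt-1", "language": "en"},
    {"code": [6], "text": "again", "duration": 1, "index": "utt-2", "language": "en"},
    {"code": [7, 8], "text": "leftover item", "duration": 1, "index": "utt-3", "language": "en"},
]

out = data_collator_concate_items(toy_batch, concat_items_num=3)
# Items 0-2 are merged into one longer utterance; the remainder (item 3) forms its own group.
assert out["speech_tokens"][0] == [1, 2, 3, 4, 5, 6]
assert out["messages"][0][1]["content"] == "hello world again"
assert out["durations"] == [4, 1]
assert out["ids"] == [["utt-0", "utt-1", "utt-2"], ["utt-3"]]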
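
The run() change switches from load_dataset(params.dataset) to reading every *.jsonl manifest under --dataset with the Hugging Face "json" loader. A minimal sketch of that branch, using a placeholder directory instead of the cluster path from the patch:

# Sketch of the non-ultra_chat branch of run(): read *.jsonl manifests from a directory
# and hold out 1000 items for evaluation. "manifests_emilia_en" is a placeholder path.
from pathlib import Path
from datasets import load_dataset

data_dir = Path("manifests_emilia_en")
json_file_lists = [str(f) for f in data_dir.glob("*.jsonl")]
ds = load_dataset("json", data_files=json_file_lists, split="train").shuffle(seed=42)
split = ds.train_test_split(test_size=1000, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]
print(len(train_dataset), len(eval_dataset))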
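
get_dataset_statistics in utils.py accepts both manifest layouts handled by the collators (flat code/text items and lhotse-style items with a custom dict) and writes length percentiles to a file. A small sketch with made-up items; the output filename is arbitrary:

# Sketch only: toy items in the flat "code"/"text" layout.
from utils import get_dataset_statistics  # assumes qwen_omni/ is on PYTHONPATH

toy_dataset = [
    {"code": [11, 12, 13], "text": "short sample"},
    {"code": [21, 22, 23, 24, 25], "text": "a slightly longer sample"},
]
speech_lens, text_lens = get_dataset_statistics(toy_dataset, "stats_toy.txt")
print(speech_lens, text_lens)  # [3, 5] [12, 24]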