diff --git a/egs/speech_llm/SPEECH2SPEECH/README.md b/egs/speech_llm/SPEECH2SPEECH/README.md
new file mode 100644
index 000000000..9a0b62914
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/README.md
@@ -0,0 +1,55 @@
+
+# Introduction
+
+This recipe includes scripts for training speech2speech models.
+
+# SPEECH2SPEECH
+
+The following table lists the folders for different tasks.
+
+|Recipe | Speech Input | Speech Output | Comment|
+|--------------|--------------|---------------|--------|
+|Qwen-Omni-like| Continuous Embeddings| CosyVoice 1 50 Hz single-codebook tokens | Text-driven; a Thinker LLM generates text tokens and a small Talker LLM generates speech tokens |
+
+### [Qwen-Omni-like Speech2Speech Recipe](./qwen_omni)
+
+A [Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni) style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.
+
+![framework](./assets/framework.png)
+
+Command for training is:
+```bash
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 50 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+```
+
+Command for decoding is:
+```bash
+python3 ./qwen_omni/decode.py \
+ --max-duration 1 \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --epoch 999 --avg 1 \
+ --manifest-dir data/fbank \
+ --use-flash-attn True \
+ --method e2e-epoch10_speech2speech \
+ --enable-speech-output True \
+ --token2wav-path models/CosyVoice-300M-SFT \
+ --use-lora True
+```
+
+Please see [`prepare.sh`](./prepare.sh) for more details.
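+
+The script takes a start stage and a stop stage as positional arguments, for example:
+
+```bash
+# Install dependencies and download pretrained models (stage 0),
+# then compute fbank features and combine the manifests (stages 1-2).
+bash prepare.sh 0 2
+```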
diff --git a/egs/speech_llm/SPEECH2SPEECH/assets/framework.png b/egs/speech_llm/SPEECH2SPEECH/assets/framework.png
new file mode 100644
index 000000000..6cd941a0b
Binary files /dev/null and b/egs/speech_llm/SPEECH2SPEECH/assets/framework.png differ
diff --git a/egs/speech_llm/SPEECH2SPEECH/exp.sh b/egs/speech_llm/SPEECH2SPEECH/exp.sh
new file mode 100644
index 000000000..26b2c8745
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/exp.sh
@@ -0,0 +1,234 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
+# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
+set -eou pipefail
+
+stage=$1
+stop_stage=$2
+
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+ echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
+ if [ ! -L "/workspace/slam" ]; then
+ cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+ fi
+ log "stage 17: Training Speech2Speech Model, full parameters"
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
+ pretrained_dir=./qwen_omni/exp_speech2text
+ ngpu=4
+
+ latest_checkpoint_step=-1
+ # Check if exp_dir exists and is a directory
+ if [ -d "$exp_dir" ]; then
+ # List directories matching checkpoint-* and find the one with the largest step number
+ for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+ checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+ # Extract step number using parameter expansion
+ current_step=${checkpoint_name#checkpoint-}
+ # Ensure current_step is a number
+ if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+ latest_checkpoint_step=$current_step
+ fi
+ done
+ fi
+
+ train_cmd_args="--max-duration 200 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --on-the-fly-feats True --on-the-fly-speed-perturb False\
+ --deepspeed \
+ --huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True --on-the-fly-feats True \
+ --dataset vocalnet_ultrachat_voiceassistant_instruct_s2s --num-epochs 10 \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
+
+ if [ "$latest_checkpoint_step" -ge 0 ]; then
+ log "Continuing training from checkpoint-$latest_checkpoint_step"
+ step=$latest_checkpoint_step
+ train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+ else
+ log "Starting training from scratch as no checkpoint was found in $exp_dir"
+ # No pretrained model or sampler state dict needed for the first run
+ fi
+
+ torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+ $train_cmd_args
+fi
+
+if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
+ echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
+    # Create the symlink if it does not already exist
+ if [ ! -L "/workspace/slam" ]; then
+ cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+ fi
+ log "stage 17: Training Speech2Speech Model, full parameters"
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
+ pretrained_dir=./qwen_omni/exp_speech2text
+ ngpu=4
+
+ latest_checkpoint_step=-1
+ # Check if exp_dir exists and is a directory
+ if [ -d "$exp_dir" ]; then
+ # List directories matching checkpoint-* and find the one with the largest step number
+ for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+ checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+ # Extract step number using parameter expansion
+ current_step=${checkpoint_name#checkpoint-}
+ # Ensure current_step is a number
+ if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+ latest_checkpoint_step=$current_step
+ fi
+ done
+ fi
+
+ train_cmd_args="--max-duration 200 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --on-the-fly-feats True --on-the-fly-speed-perturb False\
+ --deepspeed \
+ --huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True --on-the-fly-feats True \
+ --dataset vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech --num-epochs 10 \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
+
+ if [ "$latest_checkpoint_step" -ge 0 ]; then
+ log "Continuing training from checkpoint-$latest_checkpoint_step"
+ step=$latest_checkpoint_step
+ train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+ else
+ log "Starting training from scratch as no checkpoint was found in $exp_dir"
+ # No pretrained model or sampler state dict needed for the first run
+ fi
+
+ torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+ $train_cmd_args
+fi
+
+if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
+ log "stage 19: Training TTS Model"
+ exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
+ exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
+ exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
+ pretrained_dir=./qwen_omni/exp_speech2text
+ ngpu=4
+
+ latest_checkpoint_step=-1
+ # Check if exp_dir exists and is a directory
+ if [ -d "$exp_dir" ]; then
+ # List directories matching checkpoint-* and find the one with the largest step number
+ for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+ checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+ # Extract step number using parameter expansion
+ current_step=${checkpoint_name#checkpoint-}
+ # Ensure current_step is a number
+ if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+ latest_checkpoint_step=$current_step
+ fi
+ done
+ fi
+ # --dataset ultra_chat_voice_assistant
+ train_cmd_args="--batch-size 30 \
+ --exp-dir $exp_dir \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --enable-speech-input False \
+ --deepspeed \
+ --dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --num-epochs 3 \
+ --use-lora False --unfreeze-llm False --enable-speech-output True"
+
+ if [ "$latest_checkpoint_step" -ge 0 ]; then
+ log "Continuing training from checkpoint-$latest_checkpoint_step"
+ step=$latest_checkpoint_step
+ train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+ else
+ log "Starting training from scratch as no checkpoint was found in $exp_dir"
+ # No pretrained model or sampler state dict needed for the first run
+ fi
+
+ torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train_tts.py \
+ $train_cmd_args
+fi
+
+
+# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+# log "stage 20: Training TTS Model"
+# echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
+# if [ ! -L "/workspace/slam" ]; then
+# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+# fi
+# exp_dir=./qwen_omni/exp_test
+# ngpu=4
+
+# latest_checkpoint_step=-1
+# # Check if exp_dir exists and is a directory
+# if [ -d "$exp_dir" ]; then
+# # List directories matching checkpoint-* and find the one with the largest step number
+# for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+# checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+# # Extract step number using parameter expansion
+# current_step=${checkpoint_name#checkpoint-}
+# # Ensure current_step is a number
+# if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+# latest_checkpoint_step=$current_step
+# fi
+# done
+# fi
+
+# train_cmd_args="--max-duration 150 \
+# --enable-musan False \
+# --exp-dir $exp_dir \
+# --speech-encoder-path-or-name models/large-v2.pt \
+# --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+# --dataset vocalnet_ultrachat_voiceassistant \
+# --manifest-dir data/fbank \
+# --deepspeed \
+# --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+# --use-flash-attn True --on-the-fly-feats True \
+# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
+
+# if [ "$latest_checkpoint_step" -ge 0 ]; then
+# log "Continuing training from checkpoint-$latest_checkpoint_step"
+# step=$latest_checkpoint_step
+# train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+# else
+# log "Starting training from scratch as no checkpoint was found in $exp_dir"
+# # No pretrained model or sampler state dict needed for the first run
+# fi
+
+# torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+# $train_cmd_args
+# fi
+
+
+# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
+# log "stage 21: TTS Decoding Test Set"
+# exp_dir=./qwen_omni/exp_tts
+# torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
+# --exp-dir $exp_dir \
+# --speech-encoder-path-or-name models/large-v2.pt \
+# --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+# --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
+# --use-flash-attn True \
+# --enable-speech-output True \
+# --token2wav-path /workspace/CosyVoice2-0.5B \
+# --use-lora True
+# fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
new file mode 100755
index 000000000..58d7cf3d6
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
@@ -0,0 +1,291 @@
+#!/usr/bin/env python3
+# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
+# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
+# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
+# Copyright 2025 Nvidia (Yuekai Zhang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+ python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ --out-dir data/fbank \
+ --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
+ --audio-key question_audio --text-key answer \
+ --prefix ultrachat
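+
+Or, with a VocalNet-style JSON manifest (as in prepare.sh stage 6):
+    python3 local/compute_whisper_fbank.py \
+    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+    --out-dir data/fbank_voice_assistant_cosy2 \
+    --json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
+    --prefix voice_assistant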
+"""
+
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
+from vocalnet_lhotse_cutset import LazyCustomDatasetIterator
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=True,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
+ parser.add_argument(
+ "--resample-to-16kHz",
+ type=str2bool,
+ default=True,
+ help="Resample audio to 16kHz. Default: False.",
+ )
+ parser.add_argument(
+ "--speed-perturb",
+ type=str2bool,
+ default=False,
+ help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
+ )
+ parser.add_argument(
+ "--out-dir",
+ type=str,
+ default="data/fbank",
+ help="Output directory for the computed features",
+ )
+ parser.add_argument(
+ "--huggingface-dataset-path-or-name",
+ type=str,
+ default="/workspace/Belle_1.4M-SLAM-Omni",
+ help="The path or name of the Huggingface dataset",
+ )
+ parser.add_argument(
+ "--audio-key",
+ type=str,
+ default="question_audio",
+ help="The key in the Huggingface dataset containing the audio data",
+ )
+ parser.add_argument(
+ "--text-key",
+ type=str,
+ default="answer",
+ help="The key in the Huggingface dataset containing the text data",
+ )
+ parser.add_argument(
+ "--prefix",
+ type=str,
+ default="belle",
+ help="""The dataset prefix to use when saving the features""",
+ )
+ parser.add_argument(
+ "--json-file-path",
+ type=str,
+ default=None,
+ help="The path to the json file containing the vocalnet data",
+ )
+ parser.add_argument(
+ "--drop-recordings",
+ type=str2bool,
+ default=True,
+ help="Drop recordings. Default: False.",
+ )
+ parser.add_argument(
+ "--subset",
+ type=str,
+ default=None,
+ help="The subset to use from the Huggingface dataset",
+ )
+ parser.add_argument(
+ "--split",
+ type=str,
+ default="train",
+ help="The split to use from the Huggingface dataset",
+ )
+ return parser
+
+
+def remove_short_and_long_utt(c):
+    # Keep only utterances with duration between 1 second and 50 seconds.
+    #
+    # You should inspect the utterance duration distribution of your dataset
+    # (e.g. with a display_manifest_statistics.py script, as in other icefall
+    # recipes) before changing these thresholds.
+ if c.duration < 1.0 or c.duration > 50.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ return True
+
+
+def compute_fbank(args):
+ in_out_dir = Path(args.out_dir)
+ in_out_dir.mkdir(parents=True, exist_ok=True)
+ # number of workers in dataloader
+ num_workers = 4
+
+ # number of seconds in a batch
+ batch_duration = 10
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ if args.whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
+ )
+ else:
+ raise NotImplementedError("Only WhisperFbank is implemented.")
+
+ logging.info(f"device: {device}")
+
+ dataset = load_dataset(
+ args.huggingface_dataset_path_or_name,
+ args.subset,
+ streaming=True,
+ split=args.split,
+ )
+ num_shards = dataset.num_shards
+ num_digits = 5
+    # Process all shards. If a previous run was interrupted, change the start
+    # of this range to the first unprocessed shard to resume.
+    for i in range(num_shards):
+ shard = dataset.shard(num_shards, i)
+ # shard = shard.take(10) # for testing
+ logging.info(
+ f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
+ )
+
+ idx = f"{i}".zfill(num_digits)
+
+ cut_set = CutSet.from_huggingface_dataset(
+ shard, audio_key=args.audio_key, text_key=args.text_key
+ )
+
+ cut_set = cut_set.filter(remove_short_and_long_utt)
+ if args.resample_to_16kHz:
+ cut_set = cut_set.resample(16000)
+ if args.speed_perturb:
+ cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
+ logging.info("Computing features")
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{in_out_dir}/feats_{idx}_{args.subset}",
+ num_workers=num_workers,
+ batch_duration=batch_duration,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+ # cut_set = cut_set.trim_to_supervisions(
+ # keep_overlapping=False, min_duration=None
+ # )
+ cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.{args.subset}.jsonl.gz"
+ logging.info(f"Saving to {cuts_path}")
+ # see https://github.com/lhotse-speech/lhotse/issues/1125
+ if args.drop_recordings:
+ cut_set.drop_recordings().to_file(cuts_path)
+ else:
+ cut_set.to_file(cuts_path)
+
+
+def compute_fbank_vocalnet(args):
+ in_out_dir = Path(args.out_dir)
+ in_out_dir.mkdir(parents=True, exist_ok=True)
+ # number of workers in dataloader
+ num_workers = 4
+
+ # number of seconds in a batch
+ batch_duration = 10
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ if args.whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
+ )
+ else:
+ raise NotImplementedError("Only WhisperFbank is implemented.")
+
+ logging.info(f"device: {device}")
+
+ num_shards = 50
+ num_digits = 5
+ for i in range(num_shards):
+ logging.info(f"Processing shard {i}")
+ idx = f"{i}".zfill(num_digits)
+ cut_set = CutSet(
+ LazyCustomDatasetIterator(
+ json_file_path=args.json_file_path, shard_id=i, num_shards=num_shards
+ )
+ )
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False, min_duration=None
+ )
+
+ if args.resample_to_16kHz:
+ cut_set = cut_set.resample(16000)
+ if args.speed_perturb:
+ cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
+ logging.info("Computing features")
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{in_out_dir}/feats_{idx}",
+ num_workers=num_workers,
+ batch_duration=batch_duration,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+ cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
+ logging.info(f"Saving to {cuts_path}")
+ # see https://github.com/lhotse-speech/lhotse/issues/1125
+ cut_set.to_file(cuts_path)
+
+
+def main():
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+ if args.json_file_path is not None:
+ compute_fbank_vocalnet(args)
+ else:
+ compute_fbank(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py b/egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py
new file mode 100644
index 000000000..f7519fbfe
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py
@@ -0,0 +1,99 @@
+# https://huggingface.co/datasets/VocalNet/UltraChat-vocalnet/blob/main/UltraChat.json
+# https://huggingface.co/datasets/VocalNet/VoiceAssistant-430K-vocalnet/blob/main/VoiceAssistant-430K.json
+import json
+import os
+
+import numpy as np
+from lhotse import CutSet
+from lhotse.audio import Recording
+from lhotse.supervision import SupervisionSegment
+
+
+class LazyCustomDatasetIterator:
+ """
+    Lazy iterator over a VocalNet-style JSON manifest (e.g. UltraChat.json or
+    VoiceAssistant-430K.json, see the links at the top of this file) that
+    yields Lhotse cuts.
+
+    For each entry it loads the audio referenced by the ``speech`` field as a
+    ``Recording``, attaches the text of the ``gpt`` turn in ``conversations``
+    as a ``SupervisionSegment``, reads the speech tokens from the ``units``
+    .npy file, and stores them together with the remaining fields in
+    ``cut.custom``. Relative paths are resolved against the parent of the
+    JSON file's directory.
+
+    The manifest can be processed in shards: ``shard_id`` selects every
+    ``num_shards``-th entry, so different workers can process disjoint subsets.
+
+    Example::
+        >>> cuts = CutSet(LazyCustomDatasetIterator(json_file_path="UltraChat.json"))
+        ... for cut in cuts:
+        ...     print(cut)
+ """
+
+ def __init__(self, json_file_path: str, shard_id: int = 0, num_shards: int = 100):
+ self.json_file_path = json_file_path
+ self.shard_id = shard_id
+ self.num_shards = num_shards
+
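+    # Illustrative shape of a single manifest entry (only the fields read in
+    # __iter__ below are actually used; any additional keys end up in
+    # ``cut.custom``):
+    #
+    #   {
+    #     "id": "...",
+    #     "speech": "audio/....wav",
+    #     "units": "units/....npy",
+    #     "conversations": [
+    #       {"from": "...", "value": "..."},
+    #       {"from": "gpt", "value": "<assistant answer text>"}
+    #     ]
+    #   }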
+ def __iter__(self):
+
+ with open(self.json_file_path, "r", encoding="utf-8") as f:
+ list_data_dict = json.load(f)
+ list_data_dict = list_data_dict[self.shard_id :: self.num_shards]
+ for item in list_data_dict:
+ custom_data = item.copy()
+ json_file_parent_of_parent_dir = os.path.dirname(
+ os.path.dirname(self.json_file_path)
+ )
+ units_path = os.path.join(
+ json_file_parent_of_parent_dir, custom_data["units"]
+ )
+ speech_token_dict = np.load(units_path, allow_pickle=True).item()
+ speech_token = speech_token_dict["speech_token"].squeeze(0).tolist()
+ speech_token_len = speech_token_dict["speech_token_len"]
+
+ assert len(speech_token) == speech_token_len
+ custom_data["speech_token"] = speech_token
+ audio_path = custom_data.pop("speech", None)
+ audio_path = os.path.join(json_file_parent_of_parent_dir, audio_path)
+ item_id = item.get("id")
+ recording = Recording.from_file(path=audio_path, recording_id=item_id)
+
+ conversations = item.get("conversations")
+            assert isinstance(conversations, list) and len(conversations) == 2
+            gpt_text = None
+            for conv in conversations:
+ if isinstance(conv, dict) and conv.get("from") == "gpt":
+ gpt_text = conv.get("value")
+ break
+ assert gpt_text is not None
+
+ supervision = SupervisionSegment(
+ id=item_id,
+ recording_id=recording.id,
+ start=0.0, # Assuming the supervision covers the entire recording
+ duration=recording.duration,
+ text=gpt_text,
+ )
+
+ cut = recording.to_cut()
+ # cut.id will be the same as recording.id
+
+ cut.supervisions = [supervision]
+ # custom_data contains the original item's fields, minus "speech".
+ # So, "id", "conversations", "units", etc., are preserved here.
+ custom_data.pop("conversations")
+ custom_data.pop("units")
+ cut.custom = custom_data
+
+ yield cut
+
+
+if __name__ == "__main__":
+ json_file_path = (
+ "/workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json"
+ )
+ cut_set = CutSet(LazyCustomDatasetIterator(json_file_path=json_file_path))
+
+ for cut in cut_set:
+ print(cut)
+ input()
diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
new file mode 100644
index 000000000..a75cd33ff
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -0,0 +1,444 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+export PYTHONPATH=$PYTHONPATH:/workspace/icefall
+
+set -eou pipefail
+
+stage=$1
+stop_stage=$2
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: Clone CosyVoice repo and install requirements inside the container"
+ # docker: ghcr.io/swivid/f5-tts:main
+ pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
+ cd /workspace/CosyVoice
+  # If cloning the submodules fails due to network errors, rerun the following command until it succeeds.
+  git submodule update --init --recursive
+  cd -
+  pip install -r qwen_omni/requirements.txt
+  pip install -r qwen_omni/requirements-cosyvoice.txt
+
+  # For Chinese-only datasets, use the following command to download the Chinese fine-tuned Whisper model.
+  huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+  # CosyVoice pretrained model for the speech token2wav module
+  huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
+  # Qwen pretrained LLM
+ huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+ # Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
+ huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
+
+  # For the Gradio demo, we follow https://arxiv.org/abs/2412.15649 and use an ASR model to transcribe the dialogue history speech as context.
+ pip install sherpa-onnx
+ model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
+ if [ ! -d $model_path ]; then
+ wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+ tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
+ fi
+fi
+export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "stage 1: Compute fbank feature from huggingface"
+ python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ --out-dir data/fbank_test \
+ --huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
+ --audio-key question_audio --text-key answer \
+ --prefix belle
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Combine features"
+ manifest_dir=data/fbank
+ if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
+ mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
+    # exclude cuts_belle.00000.jsonl.gz, which is held out for the valid and test sets
+ pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
+ echo $pieces | wc
+ lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
+ mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
+ cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
+ ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
+ fi
+fi
+
+ngpu=8
+exp_dir=./qwen_omni/exp_speech2speech
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "stage 3: Training Speech2Speech Model"
+ torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 50 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "stage 4: Decoding, only support batch_size=1 for now."
+ cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+ python3 ./qwen_omni/decode.py \
+ --max-duration 1 \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --epoch 999 --avg 1 \
+ --manifest-dir data/fbank \
+ --use-flash-attn True \
+ --method e2e-epoch10_speech2speech \
+ --enable-speech-output True \
+ --token2wav-path models/CosyVoice-300M-SFT \
+ --use-lora True
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "stage 5: Gradio Demo"
+ python3 ./qwen_omni/web_demo.py \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --checkpoint-path $exp_dir/epoch-999.pt \
+ --use-flash-attn True \
+ --enable-speech-output True \
+ --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
+    --use-lora True --token2wav-path models/CosyVoice-300M-SFT --share
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "stage 6: Compute fbank feature from huggingface"
+ # CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
+ # --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ # --out-dir data/fbank_voice_assistant \
+ # --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \
+ # --audio-key question_audio --text-key answer \
+ # --prefix voice_assistant
+ CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ --out-dir data/fbank_voice_assistant_cosy2 \
+ --json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
+ --prefix voice_assistant
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "stage 7: Compute fbank feature from huggingface"
+ # CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
+ # --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ # --out-dir data/fbank_ultrachat \
+ # --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
+ # --audio-key question_audio --text-key answer \
+ # --prefix ultrachat
+ CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ --out-dir data/fbank_ultrachat_cosy2 \
+ --json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \
+ --prefix ultrachat
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "stage 8: Compute fbank feature from huggingface"
+
+ CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ --out-dir data/fbank_gigaspeech \
+ --huggingface-dataset-path-or-name speechcolab/gigaspeech \
+ --subset test --split test \
+ --audio-key audio --text-key text \
+ --prefix gigaspeech
+
+ CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
+ --out-dir data/fbank_gigaspeech \
+ --huggingface-dataset-path-or-name speechcolab/gigaspeech \
+ --subset xl --split train \
+ --audio-key audio --text-key text \
+ --prefix gigaspeech
+fi
+
+# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+ngpu=4
+exp_dir=./qwen_omni/exp_speech2speech_en
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "stage 10: Training Speech2Speech Model"
+ torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 150 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --dataset-format vocalnet \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True --on-the-fly-feats True \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+fi
+
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "stage 11: Decoding EN, val set only support batch_size=1 for now."
+ exp_dir=./qwen_omni/exp_speech2speech_en_continue
+ # cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+ python3 ./qwen_omni/decode.py \
+ --max-duration 1 \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --epoch 997 --avg 1 \
+ --manifest-dir data/fbank \
+ --use-flash-attn True \
+ --method e2e-epoch4_speech2speech \
+ --enable-speech-output True \
+ --token2wav-path /workspace/CosyVoice2-0.5B \
+ --use-lora True
+fi
+
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "stage 12: Decoding EN voicebench"
+ exp_dir=./qwen_omni/exp_speech2speech_en_continue
+ torchrun --nproc_per_node=2 \
+ ./qwen_omni/decode_dist.py \
+ --output-dir $exp_dir/log_voicebench \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --use-flash-attn True \
+ --enable-speech-output True \
+ --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
+ --use-lora True --subset-name openbookqa --split-name test
+fi
+
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+ log "stage 13: Server"
+ exp_dir=./qwen_omni/exp_speech2speech_en_continue
+ python3 ./qwen_omni/server.py \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
+ --use-flash-attn True \
+ --enable-speech-output True \
+ --use-lora True
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+ log "stage 14: Client"
+ exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
+ exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
+ # exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
+  # Available voicebench subsets: alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa
+  # Only the last target_datasets assignment below takes effect; edit it to select the subsets to evaluate.
+ declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
+ declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh")
+ declare -a target_datasets=("mmsu")
+
+ NUM_CLIENT_JOBS=4 # Number of parallel client jobs
+ BASE_PORT=8000 # Base port for servers
+
+ log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."
+
+ for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
+ do
+ ( # Start a subshell for backgrounding this client job's tasks
+ current_port=$(expr $BASE_PORT + $job_id)
+ log "Client Job $job_id: Initializing. Will connect to port $current_port."
+
+ processed_count_for_this_job=0
+ # Iterate over all datasets using their indices
+ for i in "${!target_datasets[@]}"; do
+ # Assign dataset to job_id in a round-robin fashion
+ if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
+ dataset="${target_datasets[$i]}"
+
+          # Determine split_name based on the dataset
+ if [ "$dataset" == "sd-qa" ]; then
+ split_name="usa"
+ else
+ split_name="test"
+ fi
+
+ log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
+ python3 ./qwen_omni/client.py \
+ --subset-name "$dataset" \
+ --split-name "$split_name" \
+ --output-dir "$exp_dir/results" \
+            --port "$current_port"
+
+ if [ $? -ne 0 ]; then
+ log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
+ fi
+ processed_count_for_this_job=$(($processed_count_for_this_job + 1))
+ fi
+ done
+ log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
+ ) & # Run this client job's subshell in the background
+ done
+
+ log "All client jobs launched. Waiting for completion..."
+ wait # Wait for all backgrounded client jobs to complete
+ log "All client jobs have completed."
+fi
+
+if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
+ log "stage 15: Training Speech2Speech Model, adaptor only"
+ exp_dir=./qwen_omni/exp_speech2text
+ ngpu=2
+ torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 700 \
+ --enable-musan False \
+ --audio-key audio --text-key continuation \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --on-the-fly-feats True \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --dataset-format speech_continuation \
+ --start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \
+ --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
+fi
+
+if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
+ log "stage 16: Training Speech2Speech Model, adaptor only"
+ exp_dir=./qwen_omni/exp_speech2text
+ ngpu=4
+
+ latest_checkpoint_step=-1
+ # Check if exp_dir exists and is a directory
+ if [ -d "$exp_dir" ]; then
+ # List directories matching checkpoint-* and find the one with the largest step number
+ for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+ checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+ # Extract step number using parameter expansion
+ current_step=${checkpoint_name#checkpoint-}
+ # Ensure current_step is a number
+ if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+ latest_checkpoint_step=$current_step
+ fi
+ done
+ fi
+
+ train_cmd_args="--max-duration 800 \
+ --enable-musan False \
+ --audio-key audio --text-key continuation \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --on-the-fly-feats True \
+ --deepspeed \
+ --huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --dataset-format speech_continuation \
+ --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
+
+ if [ "$latest_checkpoint_step" -ge 0 ]; then
+ log "Continuing training from checkpoint-$latest_checkpoint_step"
+ step=$latest_checkpoint_step
+ train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+ else
+ log "Starting training from scratch as no checkpoint was found in $exp_dir"
+ # No pretrained model or sampler state dict needed for the first run
+ fi
+
+ torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+ $train_cmd_args
+fi
+
+
+if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+ # pip install gradio sherpa-onnx
+ log "stage 17: Server for adapter only speech continuation"
+ exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
+ exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
+ exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
+
+ N_GPUS=4 # Define the number of GPUs/processes you want to launch
+
+ for id in $(seq 0 $(($N_GPUS - 1)))
+ do
+ log "Launching server on GPU $id with port $(expr 8000 + $id)"
+ CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --checkpoint-path $exp_dir/checkpoint-55276/pytorch_model.bin \
+ --use-flash-attn True \
+ --enable-speech-output False \
+ --port $(expr 18000 + $id) \
+ --use-lora True &
+ done
+
+ wait # Wait for all background processes to complete
+fi
+
+if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
+ log "stage 18: Training kl-div Speech2Speech Model, adaptor only"
+ exp_dir=./qwen_omni/exp_speech2text_kl
+ ngpu=2
+ torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 700 \
+ --enable-musan False \
+ --audio-key audio --text-key continuation \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --on-the-fly-feats True \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --dataset-format speech_continuation \
+ --loss-type kl_div --dataset librispeech \
+ --pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \
+ --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
+fi
+
+if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
+ log "stage 19: Server for kl loss"
+ exp_dir=./qwen_omni/exp_speech2text_kl
+ python3 ./qwen_omni/server.py \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+ --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
+ --use-flash-attn True \
+ --enable-speech-output False \
+ --use-lora False --prompt-template qa
+fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+ log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage"
+ exp_dir=./qwen_omni/exp_speech2text_kl_llm
+ pretrained_dir=./qwen_omni/exp_speech2text_kl
+ ngpu=2
+ torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 200 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/large-v2.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet
+fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
new file mode 100644
index 000000000..7dc279e48
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
@@ -0,0 +1,156 @@
+# client.py
+import argparse
+import json
+import os
+
+import requests
+from datasets import concatenate_datasets, load_dataset
+from tqdm import tqdm
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Speech-to-Text Client")
+ parser.add_argument(
+ "--server-url",
+ type=str,
+ default="http://localhost",
+ help="URL of the FastAPI server",
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8000,
+ help="Port of the FastAPI server",
+ )
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ default="hlt-lab/voicebench",
+ help="Hugging Face dataset name",
+ )
+ parser.add_argument(
+ "--subset-name",
+ type=str,
+ default="commoneval", # Adjust as needed
+ help="Dataset subset name",
+ )
+ parser.add_argument(
+ "--split-name",
+ type=str,
+ default=None, # Adjust as needed
+ help="Dataset split name",
+ )
+ parser.add_argument(
+ "--output-dir", required=True, type=str, help="Directory to save results"
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = get_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ output_filename = os.path.join(
+ args.output_dir,
+ f"{args.subset_name}-{args.split_name}.jsonl",
+ )
+ server_decode_url = f"{args.server_url}:{args.port}/decode"
+
+ print("Loading dataset...")
+ if args.subset_name != "mmsu":
+ dataset = load_dataset(
+ args.dataset_name,
+ args.subset_name,
+ split=args.split_name,
+ trust_remote_code=True,
+ )
+ else:
+ # load all splits and concatenate them
+ dataset = load_dataset(
+ args.dataset_name,
+ args.subset_name,
+ trust_remote_code=True,
+ )
+ dataset = concatenate_datasets([dataset[subset] for subset in dataset])
+
+ print(f"Dataset loaded with {len(dataset)} samples.")
+ print(f"Sending requests to {server_decode_url}...")
+ print(f"Saving results to {output_filename}")
+
+ with open(output_filename, "w", encoding="utf-8") as outfile:
+ # Iterate directly over the dataset
+ progress_bar = tqdm(dataset, desc="Processing", unit="samples")
+ for item in progress_bar:
+
+ audio_info = item.get("audio")
+ assert (
+ audio_info["sampling_rate"] == 16000
+ ), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
+
+ # Prepare data for JSON serialization and server request
+ audio_array = audio_info["array"].tolist() # Convert numpy array to list
+ result_dict = {}
+ for key in item.keys():
+ if key != "audio":
+ # Ensure other fields are JSON serializable
+ try:
+ # Attempt to serialize to catch issues early (optional)
+ json.dumps(item[key])
+ result_dict[key] = item[key]
+ except (TypeError, OverflowError):
+ print(
+ f"Warning: Converting non-serializable key '{key}' to string."
+ )
+ result_dict[key] = str(
+ item[key]
+ ) # Convert problematic types to string
+
+ payload = {
+ "audio": audio_array,
+ "sampling_rate": 16000,
+ }
+
+ try:
+ response = requests.post(server_decode_url, json=payload, timeout=60)
+ response.raise_for_status()
+ server_response = response.json()
+ decoded_text = server_response.get("text", "")
+
+ # Add the response to the result dictionary
+ result_dict["response"] = decoded_text
+ print(result_dict)
+ # Write result to JSONL file
+ json.dump(result_dict, outfile, ensure_ascii=False)
+ outfile.write("\n")
+
+ except requests.exceptions.RequestException as e:
+ print(f"\nError sending request for an item: {e}")
+ error_entry = result_dict # Use the data prepared so far
+ error_entry["error"] = str(e)
+ error_entry["response"] = ""
+ json.dump(error_entry, outfile, ensure_ascii=False)
+ outfile.write("\n")
+ except json.JSONDecodeError:
+ print("\nError decoding server response for an item.")
+ error_entry = result_dict
+ error_entry["error"] = "Invalid JSON response from server"
+ error_entry["response"] = ""
+ json.dump(error_entry, outfile, ensure_ascii=False)
+ outfile.write("\n")
+ except Exception as e:
+ print(f"\nUnexpected error processing an item: {e}")
+ error_entry = result_dict
+ error_entry["error"] = f"Unexpected error: {str(e)}"
+ error_entry["response"] = ""
+ json.dump(error_entry, outfile, ensure_ascii=False)
+ outfile.write("\n")
+
+ # Progress bar updates automatically by iterating over tqdm(dataset)
+
+ # No need to close progress_bar explicitly when iterating directly
+
+ print("Processing finished.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
new file mode 100644
index 000000000..457c3e107
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
@@ -0,0 +1,813 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
+from lhotse import (
+ CutSet,
+ WhisperFbank,
+ WhisperFbankConfig,
+ load_manifest,
+ load_manifest_lazy,
+)
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PerturbSpeed,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from utils import get_local_rank, str2bool
+import io
+import random
+import wave
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=300.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--on-the-fly-speed-perturb",
+ type=str2bool,
+ default=True,
+ help="When enabled, use on-the-fly speed perturbation. "
+ "Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=4,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ group.add_argument(
+ "--huggingface-dataset-path-or-name",
+ type=str,
+ default=None,
+ help="The path or name of the Huggingface dataset",
+ )
+ group.add_argument(
+ "--audio-key",
+ type=str,
+ default=None,
+ help="The key in the Huggingface dataset containing the audio data",
+ )
+ group.add_argument(
+ "--text-key",
+ type=str,
+ default=None,
+ help="The key in the Huggingface dataset containing the text data",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+ if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
+ transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ rank = get_local_rank()
+
+ train = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(
+ WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
+ )
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 1000,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=True if self.args.num_workers > 0 else False,
+ pin_memory=True,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ """
+ Args:
+ cuts_valid:
+ CutSet for validation.
+ """
+ logging.info("About to create dev dataset")
+ rank = get_local_rank()
+ validate = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(
+ WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
+ )
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ if self.args.bucketing_sampler:
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ else:
+ valid_sampler = SimpleCutSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_num_workers = 1
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=valid_num_workers,
+ persistent_workers=True if valid_num_workers > 0 else False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(
+ WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
+ )
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+    def test_cuts_belle(self) -> Dict[str, CutSet]:
+ logging.info("About to get test cuts")
+ return {
+ "test": load_manifest_lazy(
+ self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
+ )
+        }
+
+    @lru_cache()
+    def dev_cuts_belle(self) -> CutSet:
+        logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
+        )
+
+    @lru_cache()
+ def train_cuts_belle(self) -> CutSet:
+ logging.info("About to get train cuts")
+ slam_omni_zh_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_belle_train.jsonl.gz"
+ )
+ return slam_omni_zh_cuts
+
+ @lru_cache()
+ def train_cuts_en_vocalnet(self) -> CutSet:
+ logging.info("About to get train cuts")
+ VoiceAssistant_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+ )
+ ultrachat_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+ )
+ VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
+ ultrachat_cuts = ultrachat_cuts.resample(16000)
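+        # Mux the two corpora with weights proportional to their sizes, so that
+        # sampling roughly follows the natural data distribution while both
+        # manifests stay lazy.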
+ return CutSet.mux(
+ VoiceAssistant_cuts,
+ ultrachat_cuts,
+ weights=[
+ len(VoiceAssistant_cuts),
+ len(ultrachat_cuts),
+ ],
+        )
+
+    @lru_cache()
+ def valid_cuts_en_vocalnet(self) -> CutSet:
+ logging.info("About to get valid cuts")
+ VoiceAssistant_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
+ )
+ VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
+ return VoiceAssistant_cuts
+
+ @lru_cache()
+    def test_cuts_en_vocalnet(self) -> Dict[str, CutSet]:
+ logging.info("About to get test cuts")
+ VoiceAssistant_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
+ )
+ VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
+ return {"test": VoiceAssistant_cuts}
+
+ @lru_cache()
+ def train_cuts_ultravox(self) -> CutSet:
+ logging.info("About to get train cuts")
+ if self.args.huggingface_dataset_path_or_name is not None:
+ librispeech_path = (
+ self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
+ )
+ people_speech_path = (
+ self.args.huggingface_dataset_path_or_name + "/peoples_speech"
+ )
+ gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
+ else:
+ librispeech_path = "fixie-ai/librispeech_asr"
+ people_speech_path = "fixie-ai/peoples_speech"
+ gigaspeech_path = "fixie-ai/gigaspeech"
+ # 148_688
+ librispeech_other = load_dataset(
+ librispeech_path, "other", split="train.500", streaming=True
+ )
+ # 104_014
+ librispeech_clean_360 = load_dataset(
+ librispeech_path, "clean", split="train.360", streaming=True
+ )
+ # 28_539
+ librispeech_clean_100 = load_dataset(
+ librispeech_path, "clean", split="train.100", streaming=True
+ )
+
+ # 1_501_271
+ people_speech_clean = load_dataset(
+ people_speech_path, "clean", split="train", streaming=True
+ )
+ # 548_000
+ people_speech_dirty_sa = load_dataset(
+ people_speech_path, "dirty_sa", split="train", streaming=True
+ )
+
+ # 8_266_422
+
+ gigaspeech = load_dataset(
+ gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
+ )
+
+ librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_100,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ librispeech_other_cuts = CutSet.from_huggingface_dataset(
+ librispeech_other,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_360,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ gigaspeech_cuts = CutSet.from_huggingface_dataset(
+ gigaspeech, audio_key="audio", text_key="text"
+ )
+
+ people_speech_clean_cuts = CutSet.from_huggingface_dataset(
+ people_speech_clean,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
+ people_speech_dirty_sa,
+ audio_key="audio",
+ text_key="text",
+ )
+
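+        # The hard-coded weights below are the example counts of each split (also
+        # noted in the comments above), so muxing approximates sampling from the
+        # concatenation of all six corpora.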
+ return CutSet.mux(
+ librispeech_clean_100_cuts,
+ librispeech_clean_360_cuts,
+ librispeech_other_cuts,
+ gigaspeech_cuts,
+ people_speech_clean_cuts,
+ people_speech_dirty_sa_cuts,
+ weights=[
+ 28539,
+ 104014,
+ 148688,
+ 8266422,
+ 1501271,
+ 548000,
+ ],
+ )
+
+ @lru_cache()
+ def valid_cuts_ultravox(self) -> CutSet:
+ logging.info("About to get valid cuts")
+ librispeech_path = "fixie-ai/librispeech_asr"
+ librispeech_clean_valid = load_dataset(
+ librispeech_path, "clean", split="validation", streaming=True
+ )
+ librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_valid,
+ audio_key="audio",
+ text_key="text",
+ )
+ return librispeech_clean_valid_cuts
+
+ @lru_cache()
+ def train_cuts_librispeech(self) -> CutSet:
+ logging.info("About to get train cuts")
+ if self.args.huggingface_dataset_path_or_name is not None:
+ librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
+ else:
+ librispeech_path = "fixie-ai/librispeech_asr"
+ # 148_688
+ librispeech_other = load_dataset(
+ librispeech_path, "other", split="train.500", streaming=True
+ )
+ # 104_014
+ librispeech_clean_360 = load_dataset(
+ librispeech_path, "clean", split="train.360", streaming=True
+ )
+ # 28_539
+ librispeech_clean_100 = load_dataset(
+ librispeech_path, "clean", split="train.100", streaming=True
+ )
+
+ librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_100,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ librispeech_other_cuts = CutSet.from_huggingface_dataset(
+ librispeech_other,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_360,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ return CutSet.mux(
+ librispeech_clean_100_cuts,
+ librispeech_clean_360_cuts,
+ librispeech_other_cuts,
+ weights=[
+ 28539,
+ 104014,
+ 148688,
+ ],
+ )
+
+ @lru_cache()
+ def train_cuts_gigaspeech(self) -> CutSet:
+ logging.info("About to get train cuts")
+ gigaspeech_path = "fixie-ai/gigaspeech"
+ gigaspeech = load_dataset(
+ gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
+ )
+
+ gigaspeech_cuts = CutSet.from_huggingface_dataset(
+ gigaspeech, audio_key="audio", text_key="text"
+ )
+
+ return gigaspeech_cuts
+
+ @lru_cache()
+ def train_cuts_instruct_s2s(self) -> CutSet:
+ logging.info("About to get train cuts")
+ if self.args.huggingface_dataset_path_or_name is not None:
+ data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+ else:
+ data_path = "yuekai/InstructS2S-200K"
+ instruct_s2s_train = load_dataset(
+ data_path, split="train", streaming=True
+ )
+
+ instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+ instruct_s2s_train,
+ audio_key="question_audio",
+ text_key="answer",
+ )
+
+ instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+ return instruct_s2s_train_cuts
+
+ @lru_cache()
+ def train_cuts_en_speech2speech(self) -> CutSet:
+ logging.info("About to get train cuts")
+ VoiceAssistant_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+ )
+ ultrachat_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+ )
+
+ if self.args.huggingface_dataset_path_or_name is not None:
+ data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+ else:
+ data_path = "yuekai/InstructS2S-200K"
+ instruct_s2s_train = load_dataset(
+ data_path, split="train", streaming=True
+ )
+
+ instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+ instruct_s2s_train,
+ audio_key="question_audio",
+ text_key="answer",
+ )
+
+ instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+
+ return CutSet.mux(
+ VoiceAssistant_cuts,
+ ultrachat_cuts,
+ instruct_s2s_train_cuts,
+ weights=[
+ len(VoiceAssistant_cuts),
+ len(ultrachat_cuts),
+ 423_000,
+ ],
+ )
+
+ @lru_cache()
+ def train_cuts_en_speech2speech_librispeech(self) -> CutSet:
+ logging.info("About to get train cuts")
+ VoiceAssistant_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+ )
+ ultrachat_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+ )
+
+ if self.args.huggingface_dataset_path_or_name is not None:
+ data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+ else:
+ data_path = "yuekai/InstructS2S-200K"
+ instruct_s2s_train = load_dataset(
+ data_path, split="train", streaming=True
+ )
+
+ instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+ instruct_s2s_train,
+ audio_key="question_audio",
+ text_key="answer",
+ )
+
+ instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+ if self.args.huggingface_dataset_path_or_name is not None:
+ librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
+ else:
+ librispeech_path = "fixie-ai/librispeech_asr"
+ # 148_688
+ librispeech_other = load_dataset(
+ librispeech_path, "other", split="train.500", streaming=True
+ )
+ # 104_014
+ librispeech_clean_360 = load_dataset(
+ librispeech_path, "clean", split="train.360", streaming=True
+ )
+ # 28_539
+ librispeech_clean_100 = load_dataset(
+ librispeech_path, "clean", split="train.100", streaming=True
+ )
+
+ librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_100,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ librispeech_other_cuts = CutSet.from_huggingface_dataset(
+ librispeech_other,
+ audio_key="audio",
+ text_key="text",
+ )
+
+ librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+ librispeech_clean_360,
+ audio_key="audio",
+ text_key="text",
+ )
+
+
+ return CutSet.mux(
+ librispeech_other_cuts,
+ VoiceAssistant_cuts,
+ ultrachat_cuts,
+ librispeech_clean_360_cuts,
+ instruct_s2s_train_cuts,
+ librispeech_clean_100_cuts,
+ weights=[
+ 148688,
+ len(VoiceAssistant_cuts),
+ len(ultrachat_cuts),
+ 104014,
+ 423_000,
+ 28539,
+ ],
+ )
+
+ @lru_cache()
+ def train_cuts_emilia_en(self) -> CutSet:
+ logging.info("About to get train cuts")
+ data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
+ # if self.args.huggingface_dataset_path_or_name is not None:
+ # data_path = self.args.huggingface_dataset_path_or_name + "/emilia_en"
+ # else:
+ # data_path = "yuekai/emilia_en"
+
+ emilia_en_data = load_dataset(
+ data_path, split="train", streaming=True
+ )
+
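+        # The Emilia-EN shards used here ship pre-extracted codec tokens in the
+        # "code" column (renamed to "speech_token" below), so the helper simply
+        # fills the audio column with a short, structurally valid placeholder WAV
+        # that can be cast to the Audio() feature.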
+ def update_wav_path(example):
+ sampling_rate = 16000 # From current_features
+ duration = 1 # seconds, arbitrary duration for random audio
+ num_channels = 1 # mono
+ sample_width = 2 # 2 bytes = 16-bit audio
+
+ num_frames = int(duration * sampling_rate)
+
+ # Generate random bytes for the PCM data part
+ # This will be random noise, but structurally valid for a WAV file
+ pcm_data = bytes([random.randint(0, 255) for _ in range(num_frames * num_channels * sample_width)])
+
+ # Create a WAV file in memory
+ audio_buffer = io.BytesIO()
+ with wave.open(audio_buffer, 'wb') as wf:
+ wf.setnchannels(num_channels)
+ wf.setsampwidth(sample_width)
+ wf.setframerate(sampling_rate)
+ wf.writeframes(pcm_data) # writeframes expects bytes
+
+ example["wav"] = audio_buffer.getvalue()
+ return example
+
+ emilia_en_data = emilia_en_data.map(update_wav_path)
+ current_features = Features({
+ 'id': Value('string'),
+ 'text': Value('string'),
+ 'duration': Value('float'),
+ 'language': Value('string'),
+ 'dnsmos': Value('float'),
+ 'speech_token': Sequence(Value('int32')),
+ 'wav': Audio(sampling_rate=16000)
+
+ })
+ emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
+ emilia_en_data = emilia_en_data.cast(current_features)
+
+ emilia_en_train_cuts = CutSet.from_huggingface_dataset(
+ emilia_en_data, # Adjusted from instruct_s2s_train
+ audio_key="wav",
+ text_key="text",
+ )
+ return emilia_en_train_cuts
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
new file mode 100755
index 000000000..8e915cf26
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
@@ -0,0 +1,759 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+# Fangjun Kuang,
+# Wei Kang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# Command for decoding using fine-tuned models:
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Cosyvoice pretrained model for speech token2wav module
+huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
+huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
+
+cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+python3 ./qwen_omni/decode.py \
+--max-duration 1 \
+--exp-dir $exp_dir \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+--epoch 999 --avg 1 \
+--manifest-dir data/fbank \
+--use-flash-attn True \
+--method e2e-epoch10_speech2speech \
+--enable-speech-output True \
+--token2wav-path models/CosyVoice-300M-SFT \
+--use-lora True
+"""
+
+import argparse
+import logging
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import soundfile as sf
+import torch
+import torch.nn as nn
+import transformers
+import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+from data_module import AsrDataModule
+from lhotse.cut import Cut
+from model import SPEECH_LLM, EncoderProjector
+from peft import LoraConfig, get_peft_model
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def audio_decode_cosyvoice2(
+ audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
+):
+ """
+    Generate a waveform from codec tokens with CosyVoice2, conditioned on a
+    reference prompt (zero-shot style).
+
+    Args:
+        audio_tokens (torch.Tensor): Codec tokens of shape (1, num_tokens).
+        prompt_text (str): Transcript of the prompt speech.
+        prompt_speech_16k (torch.Tensor): Prompt waveform sampled at 16 kHz.
+        codec_decoder: CosyVoice2 model used as the token-to-wav decoder.
+
+ Returns:
+ torch.Tensor: Generated audio waveform.
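+
+    A usage sketch (`tokens`, the paths and the prompt text are illustrative):
+
+        >>> codec = CosyVoice2("models/CosyVoice2-0.5B", load_jit=False, fp16=False)
+        >>> prompt_wav = load_wav("./assets/common_voice_en_2586258.wav", 16000)
+        >>> wav = audio_decode_cosyvoice2(tokens, "prompt text", prompt_wav, codec)
+        >>> sf.write("out.wav", wav.squeeze(0).cpu().numpy(), 24000)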
+ """
+ model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
+ "empty", prompt_text, prompt_speech_16k, 24000
+ )
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
+ codec_decoder.model.device
+ ),
+ prompt_token_len=torch.tensor(
+ [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
+ codec_decoder.model.device
+ ),
+ prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
+ codec_decoder.model.device
+ ),
+ embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
+ finalize=True,
+ )
+
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
+
+ return audio_hat
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+ """
+    Generate a waveform from codec tokens with CosyVoice, using the built-in
+    "中文女" SFT speaker embedding (no prompt audio required).
+
+    Args:
+        audio_tokens (torch.Tensor): Codec tokens of shape (1, num_tokens).
+        codec_decoder: CosyVoice model used as the token-to-wav decoder.
+
+ Returns:
+ torch.Tensor: Generated audio waveform.
+ """
+ flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
+ flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+ prompt_speech_feat = torch.zeros(1, 0, 80)
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+ prompt_token_len=torch.tensor(
+ [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+ prompt_feat_len=torch.tensor(
+ [prompt_speech_feat.shape[1]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ embedding=flow_embedding.to(codec_decoder.model.device),
+ flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+ )
+
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
+
+ return audio_hat
+
+
+def get_model(params, device):
+ """Load and prepare the speech-to-speech model."""
+ if params.remove_whisper_encoder_input_length_restriction:
+ replace_whisper_encoder_forward()
+
+ whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+ speech_encoder = whisper_model.encoder
+ speech_encoder_dim = whisper_model.dims.n_audio_state
+ tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+
+ if params.use_flash_attn:
+ attn_implementation = "flash_attention_2"
+ # torch_dtype=torch.bfloat16 FIX ME
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "left"
+
+ else:
+ attn_implementation = "eager"
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "right"
+
+ llm = AutoModelForCausalLM.from_pretrained(
+ params.llm_path_or_name,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch_dtype,
+ )
+ if params.use_lora:
+ lora_config = LoraConfig(
+ r=64,
+ lora_alpha=16,
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ],
+ task_type="CAUSAL_LM",
+ )
+ llm = get_peft_model(llm, lora_config)
+ llm.print_trainable_parameters()
+
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+ llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+ llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+ DEFAULT_SPEECH_TOKEN
+ )
+
+ encoder_projector = EncoderProjector(
+ speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+ )
+
+ if params.enable_speech_output:
+ # Determine attn_implementation and torch_dtype based on use_flash_attn
+ if params.use_flash_attn:
+ attn_implementation = "flash_attention_2"
+ torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
+ else:
+ attn_implementation = "eager"
+ torch_dtype = torch.float16
+
+ # TODO: FIX ME
+ # codec_vocab_size = 4096 + 4
+ codec_vocab_size = 6561 + 4
+ config = Qwen2Config(
+ vocab_size=codec_vocab_size,
+ hidden_size=1024,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ intermediate_size=2048,
+ max_position_embeddings=4096,
+ )
+
+ codec_lm = AutoModelForCausalLM.from_config(
+ config=config,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch_dtype,
+ )
+
+ codec_lm.resize_token_embeddings(codec_vocab_size)
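+        # The last four ids of the codec vocabulary are reserved as special
+        # tokens: pad = V-1, eos = V-2, bos = V-3, mask = V-4 (V = codec_vocab_size).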
+ codec_lm.vocab_size = codec_vocab_size
+ codec_lm.config.pad_token_id = codec_vocab_size - 1
+ codec_lm.config.eos_token_id = codec_vocab_size - 2
+ codec_lm.config.bos_token_id = codec_vocab_size - 3
+ codec_lm.config.mask_token_id = codec_vocab_size - 4
+ else:
+ codec_lm = None
+
+ model = SPEECH_LLM(
+ speech_encoder,
+ llm,
+ encoder_projector,
+ codec_lm,
+ codec_lm_padding_side="left" if params.use_flash_attn else "right",
+ )
+
+ if params.avg > 1:
+ start = params.epoch - params.avg + 1
+ assert start >= 1, start
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ assert "model" not in checkpoint
+ # deepspeed converted checkpoint only contains model state_dict
+ filenames = [
+ f"{params.exp_dir}/epoch-{epoch}.pt"
+ for epoch in range(start, params.epoch + 1)
+ ]
+ avg_checkpoint = average_checkpoints(filenames)
+ model.load_state_dict(avg_checkpoint, strict=False)
+
+ filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save(avg_checkpoint, filename)
+ else:
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ model.load_state_dict(checkpoint, strict=False)
+
+ model.to(device)
+ model.eval()
+ return model, tokenizer
+
+
+def average_checkpoints(
+ filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+ """Average a list of checkpoints.
+    The function is mainly used for averaging DeepSpeed-converted checkpoints, which only contain the model state_dict.
+
+ Args:
+ filenames:
+ Filenames of the checkpoints to be averaged. We assume all
+ checkpoints are saved by :func:`save_checkpoint`.
+ device:
+ Move checkpoints to this device before averaging.
+ Returns:
+ Return a dict (i.e., state_dict) which is the average of all
+ model state dicts contained in the checkpoints.
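+
+    A minimal usage sketch (paths and `model` are illustrative):
+
+        >>> filenames = [Path(f"exp/epoch-{e}.pt") for e in range(8, 11)]
+        >>> avg = average_checkpoints(filenames)
+        >>> model.load_state_dict(avg, strict=False)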
+ """
+ n = len(filenames)
+
+ if "model" in torch.load(filenames[0], map_location=device):
+ avg = torch.load(filenames[0], map_location=device)["model"]
+ else:
+ avg = torch.load(filenames[0], map_location=device)
+
+ # Identify shared parameters. Two parameters are said to be shared
+ # if they have the same data_ptr
+ uniqued: Dict[int, str] = dict()
+
+ for k, v in avg.items():
+ v_data_ptr = v.data_ptr()
+ if v_data_ptr in uniqued:
+ continue
+ uniqued[v_data_ptr] = k
+
+ uniqued_names = list(uniqued.values())
+
+ for i in range(1, n):
+ if "model" in torch.load(filenames[i], map_location=device):
+ state_dict = torch.load(filenames[i], map_location=device)["model"]
+ else:
+ state_dict = torch.load(filenames[i], map_location=device)
+ for k in uniqued_names:
+ avg[k] += state_dict[k]
+
+ for k in uniqued_names:
+ if avg[k].is_floating_point():
+ avg[k] /= n
+ else:
+ avg[k] //= n
+
+ return avg
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=-1,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="beam-search",
+ help="""Decoding method.
+ Supported values are:
+ - beam-search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=1,
+ help="beam size for beam search decoding",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="whisper/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--token2wav-path",
+ type=str,
+ default="/workspace/CosyVoice-300M-SFT",
+ help="The path to the token2wav model",
+ )
+
+ parser.add_argument(
+ "--prompt_text",
+ type=str,
+ default="Romeo and Juliet might be the most famous act of William Shakespeare.",
+ help="The prompt text",
+ )
+
+ parser.add_argument(
+ "--prompt_speech_path",
+ type=str,
+ default="./assets/common_voice_en_2586258.wav",
+ help="The path to the prompt speech",
+ )
+
+ add_model_arguments(parser)
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict({})
+ return params
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: AutoTokenizer,
+ token2wav_model: nn.Module,
+ batch: dict,
+) -> Dict[str, List[str]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+    - key: "beam-search"
+    - value: A list of decoded hypothesis strings, one per utterance.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
+ Returns:
+ Return a dict, whose key may be "beam-search".
+ """
+
+ def preprocess(
+ messages,
+ tokenizer: transformers.PreTrainedTokenizer,
+ ) -> Dict:
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ add_generation_prompt=False,
+ chat_template=TEMPLATE,
+ padding="longest",
+ truncation=False,
+ )
+ )
+ max_len_texts = max([len(text) for text in texts])
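+        # Pad on the side configured on the tokenizer: left padding is used with
+        # flash-attention batching, right padding otherwise (see get_model()).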
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+ return input_ids, attention_mask
+
+ dtype = torch.float32
+ device = model.llm.device
+
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device, dtype=dtype).transpose(1, 2)
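+    # Whisper's encoder expects a fixed 30 s input (3000 mel frames at a 10 ms hop),
+    # so shorter features are zero-padded along time unless the monkey-patched
+    # encoder, which lifts this restriction, is in use.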
+ if not params.remove_whisper_encoder_input_length_restriction:
+ T = 3000
+ if feature.shape[2] < T:
+ feature = torch.cat(
+ [
+ feature,
+ torch.zeros(
+ feature.shape[0], feature.shape[1], T - feature.shape[2]
+ ).to(device, dtype=dtype),
+ ],
+ 2,
+ )
+
+ # chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
+
+ # questions_with_history = [
+ # cut.custom["question"] for cut in batch["supervisions"]["cut"]
+ # ]
+ # history_contexts = [
+ # question.rsplit(":", 1)[0].strip() for question in questions_with_history
+ # ]
+ # last_questions = [
+ # question.split(": ")[-1].strip() for question in questions_with_history
+ # ]
+ # messages = []
+ # for i, total_round in enumerate(chat_rounds):
+ # message = []
+ # if total_round > 1:
+ # history_question_answer = history_contexts[i].split("USER:")
+ # history_question_answer = [item for item in history_question_answer if item]
+ # for j in range(total_round - 1):
+ # question_answer = history_question_answer[j].split("ASSISTANT:")
+ # message += [
+ # {"role": "user", "content": question_answer[0].strip()},
+ # {"role": "assistant", "content": question_answer[1].strip()},
+ # ]
+ # message += [
+ # {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+ # {"role": "assistant", "content": ""},
+ # ]
+ # print(f"message: {message}, batch_size {len(chat_rounds)}")
+ # messages.append(message)
+ messages = []
+ for i in range(len(batch["supervisions"]["cut"])):
+ message = [
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+ {"role": "assistant", "content": ""},
+ ]
+ messages.append(message)
+ input_ids, attention_mask = preprocess(messages, tokenizer)
+ if params.enable_speech_output:
+ generated_ids, generated_speech_output = model.decode_with_speech_output(
+ feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+ )
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ generated_speech_output = [
+ generated_speech_output
+ ] # WAR: only support batch = 1 for now
+ for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
+ speech_file_name = params.log_dir / f"{cut_id}.wav"
+ # audio_tokens = [token for token in audio_tokens if token < 4096]
+ audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+ if "CosyVoice2" in params.token2wav_path:
+ prompt_speech_16k = load_wav(params.prompt_speech_path, 16000)
+ audio_hat = audio_decode_cosyvoice2(
+ audio_tokens,
+ params.prompt_text,
+ prompt_speech_16k,
+ token2wav_model,
+ )
+ sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
+ else:
+ audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+ sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
+ else:
+ generated_ids = model.decode(
+ feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+ )
+ hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+ print(f"hyps: {hyps}")
+ return {"beam-search": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: AutoTokenizer,
+ token2wav_model: nn.Module,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ The dataloader.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ Returns:
+ Return a dict, whose key may be "beam-search".
+ """
+ results = []
+
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ # questions_with_history = [
+ # cut.custom["question"] for cut in batch["supervisions"]["cut"]
+ # ]
+ # texts = [
+ # question.split(": ")[-1].strip()
+ # for question in questions_with_history
+ # ]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ token2wav_model=token2wav_model,
+ batch=batch,
+ tokenizer=tokenizer,
+ )
+
+ for lm_scale, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ print(f"ref: {ref_text}")
+ print(f"hyp: {''.join(hyp_words)}")
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[lm_scale].extend(this_batch)
+
+ num_cuts += len(batch["supervisions"]["text"])
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+
+ enable_log = True
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ if enable_log:
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out CERs, per-character error statistics and aligned
+        # ref/hyp pairs.
+ errs_filename = (
+ params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ )
+ test_set_wers[key] = wer
+
+ if enable_log:
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tCER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+ params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
+ params.log_dir.mkdir(parents=True, exist_ok=True)
+ setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
+
+ logging.info("Decoding started")
+ logging.info(params)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+
+ logging.info(f"device: {device}")
+
+ model, tokenizer = get_model(params, device)
+ if "CosyVoice2" in params.token2wav_path:
+ token2wav_model = CosyVoice2(
+ params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+ )
+ else:
+ token2wav_model = CosyVoice(
+ params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+ )
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ args.return_cuts = True
+ data_module = AsrDataModule(args)
+
+ def remove_long_utt(c: Cut):
+        # Keep only utterances shorter than 30 seconds.
+        if c.duration > 30.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from decoding. Duration: {c.duration}"
+ )
+ return False
+ return True
+
+ # TODO: FIX ME
+ # test_sets_cuts = data_module.test_cuts_belle()
+ test_sets_cuts = data_module.test_cuts_en_vocalnet()
+ test_sets = test_sets_cuts.keys()
+ test_dls = [
+ data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
+ for cuts_name in test_sets
+ ]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ token2wav_model=token2wav_model,
+ tokenizer=tokenizer,
+ )
+
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+ logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
new file mode 100644
index 000000000..dd69fce10
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
@@ -0,0 +1,256 @@
+# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
+# 2025 (authors: Yuekai Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
+""" Example Usage
+split=test_zh
+llm_path=f5-tts/exp_zh/checkpoint-805000
+huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
+model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
+huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
+vocoder=./bigvgan_v2_24khz_100band_256x
+torchrun --nproc_per_node=2 \
+ f5-tts/infer_dist.py \
+ --output_dir $output_dir \
+ --batch_size 1 \
+ --num_workers 2 \
+ --llm-model-name-or-path $llm_path \
+ --flow-matching-model-path $model_path \
+ --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
+ --use-cosyvoice-semantic-token True \
+ --vocoder-dir $vocoder \
+ --split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
+ --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import whisper
+from datasets import load_dataset
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from tqdm import tqdm
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoTokenizer
+from web_demo import get_model
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
+# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+try:
+ torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+ pass
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="extract speech code")
+ parser.add_argument(
+ "--split-name",
+ type=str,
+ default="test",
+ help="huggingface dataset split name",
+ )
+ parser.add_argument(
+ "--subset-name",
+ type=str,
+ default="commoneval",
+ help="subset name",
+ )
+ parser.add_argument(
+ "--output-dir", required=True, type=str, help="dir to save result"
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=1,
+ help="batch size (per-device) for inference",
+ )
+ parser.add_argument(
+ "--num-workers", type=int, default=2, help="workers for dataloader"
+ )
+ parser.add_argument(
+ "--prefetch", type=int, default=2, help="prefetch for dataloader"
+ )
+ parser.add_argument(
+ "--checkpoint-path",
+ type=str,
+ default=None,
+ help="Checkpoint name or path, default to %(default)r",
+ )
+ # parser.add_argument(
+ # "--top-k",
+ # type=int,
+ # default=50,
+ # help="top k for sampling",
+ # )
+ # parser.add_argument(
+ # "--top-p",
+ # type=float,
+ # default=0.95,
+ # help="top p for sampling",
+ # )
+ # parser.add_argument(
+ # "--temperature",
+ # type=float,
+ # default=0.8,
+ # help="temperature for sampling",
+ # )
+ add_model_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+def init_distributed():
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ rank = int(os.environ.get("RANK", 0))
+ print(
+ "Inference on multiple gpus, this gpu {}".format(local_rank)
+ + ", rank {}, world_size {}".format(rank, world_size)
+ )
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group("nccl")
+ return world_size, local_rank, rank
+
+
+def preprocess(
+ messages,
+ tokenizer,
+):
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ add_generation_prompt=False,
+ chat_template=TEMPLATE,
+ padding="longest",
+ truncation=False,
+ )
+ )
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+ return input_ids, attention_mask
+
+
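+# VoiceBench rows carry a nested "audio" dict; with batch_size=1 this collator
+# unpacks the waveform array and passes the remaining metadata fields through.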
+def custom_collate(batch):
+ assert len(batch) == 1
+ audio = batch[0]["audio"]
+ assert audio["sampling_rate"] == 16000
+ result = {"audio": audio["array"]}
+    for key in batch[0].keys():
+        if key != "audio":
+            result[key] = batch[0][key]
+ return result
+
+
+def main():
+ args = get_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ assert torch.cuda.is_available()
+ world_size, local_rank, rank = init_distributed()
+ device = torch.device(f"cuda:{local_rank}")
+
+ dataset = load_dataset(
+ "hlt-lab/voicebench",
+ args.subset_name,
+ split=args.split_name,
+ trust_remote_code=True,
+ )
+
+ model, tokenizer = get_model(args)
+ # tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)
+ sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ sampler=sampler,
+ shuffle=False,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch,
+ collate_fn=custom_collate,
+ )
+
+ total_steps = len(dataset)
+
+ if rank == 0:
+ progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+ message = [
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+ {"role": "assistant", "content": ""},
+ ]
+ input_ids, attention_mask = preprocess([message], tokenizer)
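+    # The same single-turn prompt (speech placeholder followed by an empty
+    # assistant turn) is reused for every utterance, so it is tokenized once
+    # outside the decoding loop.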
+ results_jsonl_file = open(
+ os.path.join(
+ args.output_dir,
+ f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
+ ),
+ "w",
+ )
+ for batch in dataloader:
+ audio = batch["audio"]
+ audio = torch.from_numpy(audio).to(device).to(torch.float32)
+ fbank = whisper.log_mel_spectrogram(audio, device=device)
+ fbank = fbank.unsqueeze(0)
+ generated_ids = model.decode(
+ fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+ )
+ hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+ result_dict = {}
+ for key in batch.keys():
+ if key != "audio":
+ result_dict[key] = batch[key]
+ result_dict["response"] = hyps[0]
+ json.dump(result_dict, results_jsonl_file)
+ results_jsonl_file.write("\n")
+
+ if rank == 0:
+ progress_bar.update(world_size * args.batch_size)
+
+    if rank == 0:
+        progress_bar.close()
+
+    # Flush and close the per-rank results file before synchronizing.
+    results_jsonl_file.close()
+
+ dist.barrier()
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_tts.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_tts.py
new file mode 100755
index 000000000..c9383232c
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_tts.py
@@ -0,0 +1,310 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+
+# Decode the seed_tts_cosy2 TTS test set (flags are illustrative; see get_parser(),
+# add_model_arguments() and add_training_arguments() for the full list):
+torchrun --nproc_per_node $ngpu ./qwen_omni/decode_tts.py \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --use-flash-attn True --use-lora True --enable-speech-output True \
+    --token2wav-path models/CosyVoice2-0.5B \
+    --split-name test_en --batch-size 1
+"""
+
+import argparse
+import copy
+import logging
+import os
+import random
+import sys
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import soundfile as sf
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import transformers
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from datasets import Audio, load_dataset
+from decode import audio_decode_cosyvoice2
+from label_smoothing import LabelSmoothingLoss
+from lhotse.utils import fix_random_seed
+from model import IGNORE_TOKEN_ID, SPEECH_LLM
+from peft import LoraConfig, get_peft_model
+from torch import Tensor
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+from train import add_model_arguments, add_training_arguments, get_model, get_params
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ Qwen2Config,
+ Qwen2ForCausalLM,
+)
+from utils import ( # filter_uneven_sized_batch,
+ AttributeDict,
+ MetricsTracker,
+ get_local_rank,
+ get_rank,
+ get_world_size,
+ setup_logger,
+ str2bool,
+)
+
+# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+DEFAULT_SPEECH_TOKEN = ""
+try:
+ torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+ pass
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=1,
+ help="The batch size to use.",
+ )
+
+ parser.add_argument(
+ "--split-name",
+ type=str,
+ default="test_en",
+ choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
+ help="huggingface dataset split name",
+ )
+ parser.add_argument(
+ "--token2wav-path",
+ type=str,
+ default="/workspace/CosyVoice-300M-SFT",
+ help="The path to the token2wav model",
+ )
+
+ add_model_arguments(parser)
+ add_training_arguments(parser)
+
+ return parser
+
+
+def preprocess(
+ messages,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ chat_template=TEMPLATE,
+ add_generation_prompt=False,
+ padding="longest", # FIX me change padding to longest
+ truncation=False,
+ )
+ )
+ if len(texts) != len(messages):
+ logging.warning(f"Remove too long text, {messages} ")
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ target_ids = input_ids.clone()
+ target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+ # mask all tokens before token_id with IGNORE_TOKEN_ID
+ # first get the indices of the tokens
+ mask_prompt = True
+ if mask_prompt:
+ default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+ mask_indices = torch.where(input_ids == default_speech_token_id)
+ for i in range(mask_indices[0].size(0)):
+ row = mask_indices[0][i]
+ col = mask_indices[1][i]
+            # Mask everything up to col + 6 (the speech placeholder plus the
+            # trailing prompt tokens such as 'assistant' and newlines) so that
+            # only the response tokens are scored. WAR: TODO FIXME check qwen3.
+            # THIS IS THE ONLY DIFFERENCE FROM the training-time preprocess.
+ target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+ target_ids[row, col] = default_speech_token_id
+ # remove default_speech_token_id from target_ids and input_ids
+ batch_size = target_ids.size(0)
+
+ target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
+ input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+ return input_ids, attention_mask, target_ids
+
+
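+# Each row becomes one TTS request: the target text is wrapped in a
+# "Generate a speech from the following text" user turn, while the prompt text
+# and 16 kHz prompt audio are kept for CosyVoice2 zero-shot voice cloning.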
+def data_collator(batch):
+ prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
+ for i, item in enumerate(batch):
+ # speech_tokens.append(item["prompt_audio_cosy2_tokens"])
+ message_list_item = []
+ message_list_item += [
+ {
+ "role": "user",
+ "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
+ },
+ {"role": "assistant", "content": ""},
+ ]
+ messages.append(message_list_item)
+ target_texts.append(item["target_text"])
+
+ ids.append(item["id"])
+ prompt_texts.append(item["prompt_text"])
+ speech_org = item["prompt_audio"]
+
+ speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
+ speech_org = speech_org.mean(dim=0, keepdim=True)
+ prompt_speech_16k.append(speech_org)
+
+ # resample to 16k
+
+ return {
+ "prompt_texts": prompt_texts,
+ "target_texts": target_texts,
+ "prompt_speech_16k": prompt_speech_16k,
+ "messages": messages,
+ "ids": ids,
+ }
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.log_dir = Path(params.exp_dir) / "log-results-wav"
+ params.log_dir.mkdir(parents=True, exist_ok=True)
+
+ fix_random_seed(params.seed)
+
+ if rank == 0:
+ setup_logger(f"{params.exp_dir}/log/log-decode-tts")
+ logging.info(params)
+ logging.info("About to create model")
+ model, tokenizer = get_model(params)
+ if torch.cuda.is_available():
+ device = torch.device("cuda", get_local_rank())
+ else:
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+ model.to(device)
+
+ dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
+ dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))
+
+ sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+ data_loader = DataLoader(
+ dataset,
+ batch_size=params.batch_size,
+ sampler=sampler,
+ shuffle=False,
+ num_workers=1,
+ prefetch_factor=1,
+ collate_fn=data_collator,
+ )
+ token2wav_model = CosyVoice2(
+ params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+ )
+ for batch in data_loader:
+ messages = batch["messages"]
+ prompt_texts = batch["prompt_texts"]
+ prompt_speech_16k = batch["prompt_speech_16k"]
+ target_texts = batch["target_texts"]
+ ids = batch["ids"]
+ input_ids, attention_mask, _ = preprocess(messages, tokenizer)
+ generated_ids, generated_speech_output = model.decode_with_speech_output(
+ None, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+ )
+ generated_speech_output = [
+ generated_speech_output
+ ] # WAR: only support batch = 1 for now
+ for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
+ ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
+ ):
+ speech_file_name = params.log_dir / f"{cut_id}.wav"
+ # save target_text to file
+ with open(params.log_dir / f"{cut_id}.txt", "w") as f:
+ f.write(f"{target_text}\n")
+ audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+ if "CosyVoice2" in params.token2wav_path:
+ audio_hat = audio_decode_cosyvoice2(
+ audio_tokens,
+ prompt_text,
+ prompt_speech,
+ token2wav_model,
+ )
+ sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
+
+ logging.info("Done!")
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ torch.set_num_threads(1)
+ # torch.set_num_interop_threads(1)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json
new file mode 120000
index 000000000..4fbacea32
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json
@@ -0,0 +1 @@
+../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
new file mode 100644
index 000000000..baec602bb
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
@@ -0,0 +1,838 @@
+import logging
+from typing import List, Tuple
+
+import torch
+from torch import nn
+from torchmetrics.classification import MulticlassAccuracy
+from transformers.trainer_pt_utils import LabelSmoother
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+class EncoderProjector(nn.Module):
+ """
+ The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
+ Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
+ Args:
+ encoder_dim (:obj:`int`): The dimension of the encoder outputs.
+ llm_dim (:obj:`int`): The dimension of the language model.
+ downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
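+
+    Example (shapes only; the dimensions below are illustrative, e.g. a
+    Whisper-large encoder feeding a Qwen2.5-0.5B LLM):
+
+        >>> proj = EncoderProjector(encoder_dim=1280, llm_dim=896, downsample_rate=5)
+        >>> x = torch.randn(2, 1500, 1280)  # (batch, frames, encoder_dim)
+        >>> proj(x).shape                   # 5 frames are stacked per output step
+        torch.Size([2, 300, 896])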
+ """
+
+ def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
+ super().__init__()
+ self.downsample_rate = downsample_rate
+ self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
+ self.relu = nn.ReLU()
+ self.linear2 = nn.Linear(llm_dim, llm_dim)
+
+ def forward(self, x):
+
+ batch_size, seq_len, feat_dim = x.size()
+ num_frames_to_discard = seq_len % self.downsample_rate
+ if num_frames_to_discard > 0:
+ x = x[:, :-num_frames_to_discard, :]
+ seq_len = x.size(1)
+
+ x = x.contiguous()
+ x = x.view(
+ batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
+ )
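+ # e.g. (B, T, D) with downsample_rate k becomes (B, T // k, D * k) before the two linear layers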
+
+ x = self.linear1(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+ return x
+
+
+class SPEECH_LLM(nn.Module):
+ """
+ The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
+ The encoder is used to extract speech features from the input speech signal.
+ The encoder projector is used to project the encoder outputs to the same dimension as the language model.
+ The language model is used to generate the text from the speech features.
+ Args:
+ encoder (:obj:`nn.Module`): The encoder module.
+ llm (:obj:`nn.Module`): The language model module.
+ encoder_projector (:obj:`nn.Module`): The encoder projector module.
+ """
+
+ def __init__(
+ self,
+ encoder: nn.Module = None,
+ llm: nn.Module = None,
+ encoder_projector: nn.Module = None,
+ codec_lm: nn.Module = None,
+ codec_lm_padding_side: str = "left",
+ teacher_llm: nn.Module = None,
+ kl_temperature: float = 2.0,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.llm = llm
+ self.encoder_projector = encoder_projector
+ self.codec_lm = codec_lm
+ if self.codec_lm:
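+ # The projector maps the concatenation of the LLM's last hidden state and its token
+ # embedding (hence the doubled input dimension) into the codec LM's embedding space.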
+ self.speech_token_projector = nn.Linear(
+ self.llm.config.hidden_size + self.llm.config.hidden_size,
+ self.codec_lm.config.hidden_size,
+ )
+ self.codec_lm_head = nn.Linear(
+ self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
+ )
+ self.speech_token_projector = self.speech_token_projector.to(
+ dtype=torch.float16
+ )
+ self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
+ self.loss_fct = torch.nn.CrossEntropyLoss()
+ self.codec_lm_padding_side = codec_lm_padding_side
+
+ self.audio_accuracy_metric = MulticlassAccuracy(
+ self.codec_lm.vocab_size,
+ top_k=10,
+ average="micro",
+ multidim_average="global",
+ ignore_index=IGNORE_TOKEN_ID,
+ )
+ if teacher_llm is not None:
+ self.teacher_llm = teacher_llm
+ self.kl_temperature = kl_temperature
+
+ def _merge_input_ids_with_speech_features(
+ self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
+ ):
+ """
+ Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
+ with the speech features and padding the input_ids to the maximum length of the speech features.
+ Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
+ Args:
+ speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
+ inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
+ input_ids (:obj:`torch.Tensor`): The input ids to merge.
+ attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
+ labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
+ Returns:
+ :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
+ """
+ num_speechs, speech_len, embed_dim = speech_features.shape
+ batch_size, sequence_length = input_ids.shape
+ left_padding = not torch.sum(
+ input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
+ )
+ # 1. Create a mask to know where special speech tokens are
+ special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
+ num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
+ # Compute the maximum embed dimension
+ max_embed_dim = (
+ num_special_speech_tokens.max() * (speech_len - 1)
+ ) + sequence_length
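+ # Each speech placeholder token expands into `speech_len` feature frames, so the merged
+ # sequence grows by (speech_len - 1) positions per placeholder.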
+ batch_indices, non_speech_indices = torch.where(
+ input_ids != self.llm.config.default_speech_token_id
+ )
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in the merged speech-text sequence.
+ # `special_speech_token_mask` identifies speech placeholder tokens. Each placeholder is replaced by
+ # `speech_len` speech feature frames, shifting every following text token by `speech_len - 1` positions.
+ # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+ new_token_positions = (
+ torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
+ )
+ nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_speech_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size,
+ max_embed_dim,
+ embed_dim,
+ dtype=inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ )
+ final_attention_mask = torch.zeros(
+ batch_size,
+ max_embed_dim,
+ dtype=attention_mask.dtype,
+ device=inputs_embeds.device,
+ )
+ if labels is not None:
+ final_labels = torch.full(
+ (batch_size, max_embed_dim),
+ IGNORE_TOKEN_ID,
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ # In case the speech encoder or the language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = inputs_embeds.device
+ batch_indices, non_speech_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_speech_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+ attention_mask = attention_mask.to(target_device)
+
+ # 4. Fill the embeddings based on the mask. If we have ["hey", "<speech>", "how", "are"]
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
+ batch_indices, non_speech_indices
+ ]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
+ batch_indices, non_speech_indices
+ ]
+ if labels is not None:
+ final_labels[batch_indices, text_to_overwrite] = labels[
+ batch_indices, non_speech_indices
+ ]
+
+ # 5. Fill the embeddings corresponding to the speech features. Anything that is not `text_positions` needs filling (#29835)
+ speech_to_overwrite = torch.full(
+ (batch_size, max_embed_dim),
+ True,
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ speech_to_overwrite[batch_indices, text_to_overwrite] = False
+ speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
+ :, None
+ ].to(target_device)
+
+ if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
+ raise ValueError(
+ f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
+ f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
+ )
+
+ final_embedding[speech_to_overwrite] = (
+ speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ )
+ final_attention_mask |= speech_to_overwrite
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
+ (final_attention_mask == 0), 1
+ )
+
+ # 6. Mask out the embedding at padding positions, as we later rely on past_key_values to determine the non-attended tokens.
+ batch_indices, pad_indices = torch.where(
+ input_ids == self.llm.config.pad_token_id
+ )
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+ final_embedding[batch_indices, indices_to_mask] = 0
+
+ if labels is None:
+ final_labels = None
+
+ return final_embedding, final_attention_mask, final_labels, position_ids
+
+ def forward(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor = None,
+ labels: torch.LongTensor = None,
+ ):
+ encoder_outs = self.encoder(fbank)
+
+ speech_features = self.encoder_projector(encoder_outs)
+
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+ (
+ inputs_embeds,
+ attention_mask,
+ labels,
+ _,
+ ) = self._merge_input_ids_with_speech_features(
+ speech_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
+ )
+
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc = compute_accuracy(
+ preds.detach()[:, :-1],
+ labels.detach()[:, 1:],
+ ignore_label=IGNORE_TOKEN_ID,
+ )
+ return model_outputs.loss, acc
+
+ def forward_kl_div(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor = None,
+ labels: torch.LongTensor = None,
+ teacher_input_ids: torch.LongTensor = None,
+ teacher_attention_mask: torch.Tensor = None,
+ teacher_labels: torch.LongTensor = None,
+ ):
+ encoder_outs = self.encoder(fbank)
+
+ speech_features = self.encoder_projector(encoder_outs)
+
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+ (
+ inputs_embeds,
+ attention_mask,
+ labels,
+ _,
+ ) = self._merge_input_ids_with_speech_features(
+ speech_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
+ )
+
+ teacher_outputs = self.teacher_llm(
+ input_ids=teacher_input_ids,
+ attention_mask=teacher_attention_mask,
+ )
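+ # Distillation: align the student's answer-position distributions (labels != -100) with the
+ # teacher's (teacher_labels != -100), both softened by `kl_temperature`.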
+
+ kl_loss = torch.nn.functional.kl_div(
+ torch.nn.functional.log_softmax(
+ model_outputs.logits[labels != -100] / self.kl_temperature,
+ dim=-1,
+ ),
+ torch.nn.functional.softmax(
+ teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature,
+ dim=-1,
+ ),
+ reduction="batchmean",
+ )
+
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ teacher_preds = torch.argmax(teacher_outputs.logits, -1)
+ acc = compute_accuracy(
+ preds.detach()[:, :-1],
+ labels.detach()[:, 1:],
+ ignore_label=IGNORE_TOKEN_ID,
+ )
+ acc_teacher = compute_accuracy(
+ teacher_preds.detach()[:, :-1],
+ teacher_labels.detach()[:, 1:],
+ ignore_label=IGNORE_TOKEN_ID,
+ )
+ return kl_loss, acc, acc_teacher
+
+ def forward_with_speech_output(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor = None,
+ labels: torch.LongTensor = None,
+ speech_codec_ids: torch.LongTensor = None,
+ ):
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+ if fbank is not None:
+ encoder_outs = self.encoder(fbank)
+ speech_features = self.encoder_projector(encoder_outs)
+ (
+ inputs_embeds,
+ attention_mask,
+ labels,
+ _,
+ ) = self._merge_input_ids_with_speech_features(
+ speech_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ input_seq_len = attention_mask.sum(dim=1) # shape, B
+ (
+ text_label_start_index_list,
+ text_input_start_index_list,
+ input_question_len_list,
+ ) = ([], [], [])
+ for i in range(labels.shape[0]):
+ input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
+ input_embeds_start_index = input_embeds_valid_index[0]
+ text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
+ text_labels_start_index = text_labels_valid_index[0]
+
+ assert (
+ input_seq_len[i]
+ == input_embeds_valid_index[-1] - input_embeds_start_index + 1
+ ), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
+ assert (
+ input_embeds_valid_index[-1] == text_labels_valid_index[-1]
+ ), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
+ input_question_len = text_labels_start_index - input_embeds_start_index
+ assert (
+ input_question_len
+ + text_labels_valid_index[-1]
+ - text_labels_start_index
+ + 1
+ == input_seq_len[i]
+ )
+ text_label_start_index_list.append(text_labels_start_index)
+ text_input_start_index_list.append(input_embeds_start_index)
+ input_question_len_list.append(input_question_len)
+
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ labels=labels,
+ output_hidden_states=True,
+ )
+ text_loss = model_outputs.loss
+ delay_step = 1
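+ # The codec LM starts emitting speech tokens `delay_step` steps after the corresponding text,
+ # so each speech token can condition on the already-produced text hidden states.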
+ # prepare codec lm inputs
+ audio_codes_lens = [
+ len(x) + input_question_len_list[i] + delay_step + 1
+ for i, x in enumerate(speech_codec_ids)
+ ]
+ max_len_speech_codec = max(audio_codes_lens)
+
+ if self.codec_lm_padding_side == "right":
+ audio_codes = [
+ [self.codec_lm.config.mask_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + [self.codec_lm.config.bos_token_id]
+ + x
+ + [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
+ for i, x in enumerate(speech_codec_ids)
+ ]
+ audio_labels = [
+ [self.codec_lm.config.pad_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + x
+ + [self.codec_lm.config.eos_token_id]
+ + [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
+ for i, x in enumerate(speech_codec_ids)
+ ]
+ elif self.codec_lm_padding_side == "left":
+ audio_codes = [
+ [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
+ + [self.codec_lm.config.mask_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + [self.codec_lm.config.bos_token_id]
+ + x
+ for i, x in enumerate(speech_codec_ids)
+ ]
+ audio_labels = [
+ [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
+ + [self.codec_lm.config.pad_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + x
+ + [self.codec_lm.config.eos_token_id]
+ for i, x in enumerate(speech_codec_ids)
+ ]
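+ # Layout illustration (left padding, question length Q, delay_step 1):
+ # audio_codes : [pad ... pad][mask x (Q+1)][bos][c1 c2 ... cN]
+ # audio_labels: [pad ... pad][pad x (Q+1)][c1 c2 ... cN][eos]
+ # i.e. the codec LM predicts each codec token one position ahead of its input.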
+ audio_codes = torch.tensor(
+ audio_codes, dtype=torch.int64, device=input_ids.device
+ )
+ audio_labels = torch.tensor(
+ audio_labels, dtype=torch.int64, device=input_ids.device
+ )
+
+ audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
+ audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
+
+ # text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+ text_input_embeds_list = []
+ for i in range(len(text_label_start_index_list)):
+ text_last_hidden = model_outputs.hidden_states[-1][
+ i,
+ text_input_start_index_list[i] : text_input_start_index_list[i]
+ + input_seq_len[i]
+ - 1,
+ ]
+ # text_last_hidden_lists.append(text_last_hidden)
+ text_embed = inputs_embeds[
+ i,
+ text_input_start_index_list[i]
+ + 1 : text_input_start_index_list[i]
+ + input_seq_len[i],
+ ] # exclude bos
+ # text_embeds_list.append(text_embed)
+
+ text_input_embeds = torch.cat(
+ [
+ text_last_hidden,
+ text_embed,
+ ],
+ dim=-1,
+ ) # shape, T, D1 + D2
+ text_input_embeds = self.speech_token_projector(
+ text_input_embeds
+ ) # shape, T, D_codec
+ text_input_embeds_list.append(text_input_embeds)
+
+ for i in range(audio_embeddings.shape[0]):
+ text_input_embeds = text_input_embeds_list[i]
+ if self.codec_lm_padding_side == "right":
+ audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
+ elif self.codec_lm_padding_side == "left":
+ start_idx = torch.where(
+ audio_codes[i] == self.codec_lm.config.mask_token_id
+ )[0][0]
+ start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
+ assert (
+ start_idx == start_idx_re_compute
+ ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
+ if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
+ logging.warning(
+ f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
+ )
+ text_input_embeds = text_input_embeds[
+ : audio_embeddings.shape[1] - start_idx
+ ]
+ audio_embeddings[
+ i, start_idx : start_idx + text_input_embeds.shape[0]
+ ] += text_input_embeds
+
+ speech_outputs = self.codec_lm(
+ attention_mask=audio_attention_mask,
+ inputs_embeds=audio_embeddings,
+ return_dict=True,
+ output_hidden_states=True,
+ )
+ last_hidden_state = speech_outputs.hidden_states[-1].clone()
+
+ audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
+ audio_logits = audio_logits.contiguous().view(
+ -1, self.codec_lm.config.vocab_size
+ )
+ audio_labels = audio_labels.contiguous().view(-1)
+ audio_labels = audio_labels.masked_fill(
+ audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
+ )
+ codec_loss = self.loss_fct(audio_logits, audio_labels)
+ audio_preds = torch.argmax(audio_logits, -1)
+
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc = compute_accuracy(
+ preds.detach()[:, :-1],
+ labels.detach()[:, 1:],
+ ignore_label=IGNORE_TOKEN_ID,
+ )
+ audio_acc = compute_accuracy(
+ audio_preds.detach(),
+ audio_labels.detach(),
+ ignore_label=IGNORE_TOKEN_ID,
+ )
+ audio_topk_acc = self.audio_accuracy_metric(
+ audio_logits.detach(), audio_labels.detach()
+ ).item()
+
+ return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
+
+ def decode(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor = None,
+ **kwargs,
+ ):
+
+ encoder_outs = self.encoder(fbank)
+ speech_features = self.encoder_projector(encoder_outs)
+ speech_features = speech_features.to(torch.float16)
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+ (
+ inputs_embeds,
+ attention_mask,
+ _,
+ _,
+ ) = self._merge_input_ids_with_speech_features(
+ speech_features, inputs_embeds, input_ids, attention_mask
+ )
+ generated_ids = self.llm.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ max_new_tokens=kwargs.get("max_new_tokens", 1024),
+ num_beams=kwargs.get("num_beams", 1),
+ do_sample=kwargs.get("do_sample", True),
+ min_length=kwargs.get("min_length", 1),
+ top_p=kwargs.get("top_p", 0.5),
+ top_k=kwargs.get("top_k", 20),
+ repetition_penalty=kwargs.get("repetition_penalty", 1.1),
+ temperature=kwargs.get("temperature", 0.7),
+ bos_token_id=self.llm.config.bos_token_id,
+ eos_token_id=self.llm.config.eos_token_id,
+ pad_token_id=self.llm.config.pad_token_id,
+ )
+
+ return generated_ids
+
+ def decode_with_speech_output(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None, # Prompt input_ids
+ attention_mask: torch.Tensor = None, # Prompt attention_mask
+ max_text_new_tokens: int = 1024,
+ max_speech_new_tokens: int = 2048, # Max length for speech tokens
+ llm_kwargs: dict = None, # Kwargs for text LLM generate
+ codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
+ ) -> Tuple[torch.LongTensor, List[List[int]]]:
+ """
+ Generates text and corresponding speech tokens using the revised logic.
+
+ Args:
+ fbank: Input audio features.
+ input_ids: Input token IDs for the text prompt.
+ attention_mask: Attention mask for the text prompt.
+ max_text_new_tokens: Max new tokens for text generation.
+ max_speech_new_tokens: Max new tokens for speech generation.
+ llm_kwargs: Additional arguments for self.llm.generate.
+ codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
+
+ Returns:
+ Tuple[torch.LongTensor, List[List[int]]]:
+ - generated_text_ids: Tensor of generated text token IDs (including prompt).
+ - generated_speech_tokens: List of lists, where each inner list contains
+ the generated speech codec tokens for a batch item.
+ """
+ batch_size = input_ids.shape[0]
+ assert batch_size == 1, "Batch size must be 1 for speech generation."
+
+ device = next(self.parameters()).device # Use model's device
+
+ prompt_embeds = self.llm.get_input_embeddings()(input_ids)
+
+ # Merge speech features with prompt embeddings
+ if fbank is not None:
+ encoder_outs = self.encoder(fbank)
+ speech_features = self.encoder_projector(encoder_outs)
+ speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
+ (
+ merged_prompt_inputs_embeds,
+ merged_prompt_attention_mask,
+ _,
+ _,
+ ) = self._merge_input_ids_with_speech_features(
+ speech_features, prompt_embeds, input_ids, attention_mask
+ )
+ else:
+ merged_prompt_inputs_embeds = prompt_embeds
+ merged_prompt_attention_mask = attention_mask
+
+ # --- 2. Generate Text using LLM ---
+ # Use merged embeds/mask as input to generate
+ # Ensure kwargs passed are suitable for llm.generate
+ # Note: Using default generation params from `decode` if not provided in kwargs
+ final_llm_kwargs = {
+ "bos_token_id": self.llm.config.bos_token_id,
+ "eos_token_id": self.llm.config.eos_token_id,
+ "pad_token_id": self.llm.config.pad_token_id,
+ "num_beams": 1,
+ "do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
+ "top_p": 0.5,
+ "top_k": 20,
+ "repetition_penalty": 1.1,
+ "temperature": 0.7,
+ **(llm_kwargs or {}), # User-provided kwargs override defaults
+ }
+
+ text_outputs = self.llm.generate(
+ inputs_embeds=merged_prompt_inputs_embeds,
+ attention_mask=merged_prompt_attention_mask,
+ max_new_tokens=max_text_new_tokens,
+ return_dict_in_generate=True,
+ output_hidden_states=True,
+ **final_llm_kwargs,
+ )
+ delay_step = 1
+ generated_text_ids = text_outputs.sequences # [B, S_full]
+ eos_token_id = self.llm.config.eos_token_id
+ eos_token_embedding = self.llm.get_input_embeddings()(
+ torch.tensor([[eos_token_id]], device=device)
+ )
+ assert (
+ generated_text_ids[0, -1] == eos_token_id
+ ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
+ thinker_token_embeds_org = [
+ token_hidden_states[0].to(self.llm.device)
+ for token_hidden_states in text_outputs.hidden_states
+ ]
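+ # `text_outputs.hidden_states` holds one tuple per generation step; within each step,
+ # index 0 is the embedding-layer output and index -1 the last transformer layer.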
+
+ first_thinker_token_embed = torch.cat(
+ [
+ thinker_token_embeds_org[0][:, 1:],
+ thinker_token_embeds_org[1],
+ ],
+ dim=1,
+ )
+
+ thinker_token_embeds = (
+ [first_thinker_token_embed]
+ + thinker_token_embeds_org[2:]
+ + [eos_token_embedding]
+ )
+ thinker_hidden_states = [
+ token_hidden_states[-1].to(self.llm.device)
+ for token_hidden_states in text_outputs.hidden_states
+ ]
+
+ thinker_reply_part = [
+ torch.cat(
+ [
+ thinker_hidden_state,
+ thinker_token_embed,
+ ],
+ dim=-1,
+ )
+ for thinker_hidden_state, thinker_token_embed in zip(
+ thinker_hidden_states[1:], thinker_token_embeds[1:]
+ )
+ ]
+ thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
+ # thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
+ thinker_prompt_part = torch.cat(
+ [
+ thinker_hidden_states[0],
+ thinker_token_embeds[0],
+ ],
+ dim=-1,
+ )
+
+ thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
+ thinker_reply_part = self.speech_token_projector(thinker_reply_part)
+
+ thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
+ talker_input_ids = torch.full(
+ (batch_size, thinker_prompt_part_seq_len + delay_step + 1),
+ self.codec_lm.config.mask_token_id,
+ dtype=torch.long,
+ device=self.llm.device,
+ )
+ talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
+ talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
+ thinker_input_embeds = torch.cat(
+ [
+ thinker_prompt_part,
+ thinker_reply_part[:, : delay_step + 1, :],
+ ],
+ dim=1,
+ )
+ talker_inputs_embeds += thinker_input_embeds
+ thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
+
+ past_key_values = None
+
+ generated_speech_tokens_list = []
+ next_token_ids = None
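+ # Autoregressive talker loop: at each step the previous codec token's embedding is added to the
+ # next remaining thinker feature (if any) and fed to the codec LM with a KV cache;
+ # generation stops at the codec EOS token.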
+
+ for t in range(max_speech_new_tokens):
+ if t > 0:
+ talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
+ next_token_ids
+ )
+ if thinker_reply_part.shape[1] > 0:
+ talker_inputs_embeds += thinker_reply_part[:, :1, :]
+ thinker_reply_part = thinker_reply_part[:, 1:, :]
+
+ codec_outputs = self.codec_lm(
+ inputs_embeds=talker_inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=True,
+ return_dict=True,
+ output_hidden_states=True,
+ )
+ last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
+ next_token_logits = self.codec_lm_head(last_token_hidden_state)
+
+ next_token_ids = topk_sampling(
+ next_token_logits,
+ )
+ if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
+ break
+
+ past_key_values = codec_outputs.past_key_values # Update KV cache
+ generated_speech_tokens_list.append(
+ next_token_ids.squeeze(1).cpu().tolist()[0]
+ )
+
+ return generated_text_ids, generated_speech_tokens_list
+
+
+def compute_accuracy(pad_outputs, pad_targets, ignore_label):
+ """Calculate accuracy.
+ Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
+ Args:
+ pad_outputs (LongTensor): Prediction tensors (B, Lmax).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ float: Accuracy value (0.0 - 1.0).
+
+ """
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(
+ pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
+ )
+ denominator = torch.sum(mask)
+ return numerator.float() / denominator.float()
+
+
+def topk_sampling(
+ logits,
+ top_k=50,
+ top_p=0.95,
+ temperature=0.8,
+):
+ if temperature != 1.0:
+ logits = logits / temperature
+ # Top-p/top-k filtering
+ logits_filtered = top_k_top_p_filtering(
+ logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+ )
+ # Sample
+ probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
+ tokens = torch.multinomial(probs, num_samples=1)
+
+ return tokens
+
+
+# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+ logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+ """
+ if top_k > 0:
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+ )
+
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+ logits[indices_to_remove] = filter_value
+ return logits
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements-cosyvoice.txt b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements-cosyvoice.txt
new file mode 100644
index 000000000..8962f76e3
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements-cosyvoice.txt
@@ -0,0 +1,23 @@
+conformer==0.3.2
+diffusers==0.29.0
+gdown==5.1.0
+gradio
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+inflect==7.3.1
+librosa==0.10.2
+lightning==2.2.4
+matplotlib==3.7.5
+#modelscope==1.15.0
+networkx==3.1
+omegaconf==2.3.0
+onnx==1.16.0
+onnxruntime-gpu==1.18.0
+protobuf==4.25
+pydantic==2.7.0
+pyworld==0.3.4
+rich==13.7.1
+soundfile==0.12.1
+tensorboard==2.14.0
+wget==3.2
+WeTextProcessing==1.0.3
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt
new file mode 100644
index 000000000..ce14647fc
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt
@@ -0,0 +1,15 @@
+openai-whisper
+kaldialign
+lhotse
+sentencepiece
+pypinyin
+tensorboard
+librosa
+deepspeed
+transformers>=4.37.0
+flash-attn
+peft
+torchmetrics
+# triton==3.3.0  # may conflict with openai-whisper
+gradio
+sherpa-onnx
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
new file mode 100644
index 000000000..f0da7f905
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
@@ -0,0 +1,131 @@
+# server.py
+import argparse
+import os
+from typing import List
+
+import torch
+import uvicorn
+import whisper
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoTokenizer
+from web_demo import get_model
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="speech-to-text decoding server")
+ parser.add_argument(
+ "--checkpoint-path",
+ type=str,
+ default=None,
+ help="Checkpoint name or path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--prompt-template",
+ type=str,
+ default=None,
+ help="Prompt template",
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8001,
+ help="Port number",
+ )
+ add_model_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+class SpeechRequest(BaseModel):
+ audio: List[float] # Expecting audio as a list of floats (raw waveform)
+ sampling_rate: int = 16000
+
+
+class TextResponse(BaseModel):
+ text: str
+
+
+def preprocess_prompt(tokenizer):
+ """Preprocesses the prompt template."""
+ texts = [
+ tokenizer.apply_chat_template(
+ message, # Using the hardcoded message
+ tokenize=True,
+ add_generation_prompt=False, # Important for generation
+ chat_template=TEMPLATE,
+ padding=False, # No padding needed for single prompt
+ truncation=False,
+ )
+ ]
+ input_ids = torch.tensor(texts, dtype=torch.long)
+ attention_mask = torch.ones_like(
+ input_ids, dtype=torch.bool
+ ) # Mask is all True for the prompt
+ return input_ids, attention_mask
+
+
+args = get_args()
+print(f"Using port: {args.port}")
+model, tokenizer = get_model(args)
+app = FastAPI()
+
+device = torch.device("cuda")
+if args.prompt_template is None:
+ template = f"{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "qa":
+ template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "continuation":
+ template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "asr":
+ template = (
+ f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
+ )
+elif args.prompt_template == "mt":
+ template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}"
+else:
+ raise ValueError(f"Invalid prompt template: {args.prompt_template}")
+print("Using template:", template)
+message = [
+ {"role": "user", "content": template},
+ {"role": "assistant", "content": ""},
+]
+TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
+prompt_input_ids = prompt_input_ids.to(device)
+prompt_attention_mask = prompt_attention_mask.to(device)
+
+
+@app.post("/decode", response_model=TextResponse)
+async def decode_speech(request: SpeechRequest):
+ """
+ Receives audio waveform, processes it, and returns the decoded text.
+ """
+ if request.sampling_rate != 16000:
+ raise HTTPException(
+ status_code=400, detail="Only 16kHz sampling rate is supported."
+ )
+
+ try:
+ audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
+ fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
+ fbank = fbank.unsqueeze(0)
+
+ with torch.no_grad():
+ generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)
+
+ hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+ response_text = hyps[0] if hyps else ""
+
+ return TextResponse(text=response_text)
+
+ except Exception as e:
+ print(f"Error during processing: {e}")
+ raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
+
+
+if __name__ == "__main__":
+ print("Starting server...")
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
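+
+# Example client call (illustrative only; assumes a local 16 kHz mono file "test.wav"):
+#   import requests, soundfile as sf
+#   audio, sr = sf.read("test.wav")
+#   resp = requests.post("http://localhost:8001/decode",
+#                        json={"audio": audio.tolist(), "sampling_rate": sr})
+#   print(resp.json()["text"])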
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
new file mode 100755
index 000000000..5b5628f74
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -0,0 +1,1160 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 50 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+"""
+
+import argparse
+import copy
+import logging
+import os
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import deepspeed
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import transformers
+import whisper
+from data_module import AsrDataModule
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
+from lhotse import CutSet, load_manifest
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
+from peft import LoraConfig, get_peft_model
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ Qwen2Config,
+ Qwen2ForCausalLM,
+)
+from utils import ( # filter_uneven_sized_batch,
+ AttributeDict,
+ MetricsTracker,
+ get_local_rank,
+ get_rank,
+ get_world_size,
+ setup_logger,
+ str2bool,
+)
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+DEFAULT_SPEECH_TOKEN = "<speech>"
+try:
+ torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+ pass
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--remove-whisper-encoder-input-length-restriction",
+ type=str2bool,
+ default=True,
+ help="replace whisper encoder forward method to remove input length restriction",
+ )
+ parser.add_argument(
+ "--llm-path-or-name",
+ type=str,
+ default="/workspace/asr/Qwen1.5-0.5B-Chat",
+ help="Path or name of the large language model.",
+ )
+
+ parser.add_argument(
+ "--speech-encoder-path-or-name",
+ type=str,
+ default="whisper-large-v2",
+ help="Path or name of the speech encoder.",
+ )
+
+ parser.add_argument(
+ "--encoder-projector-ds-rate",
+ type=int,
+ default=8,
+ help="Downsample rate for the encoder projector.",
+ )
+ parser.add_argument(
+ "--use-flash-attn",
+ type=str2bool,
+ default=True,
+ help="Whether to use flash attention.",
+ )
+
+ parser.add_argument(
+ "--use-lora",
+ type=str2bool,
+ default=False,
+ help="Whether to use lora to fine-tune llm.",
+ )
+
+ parser.add_argument(
+ "--enable-speech-output",
+ type=str2bool,
+ default=False,
+ help="Whether to enable speech codec output.",
+ )
+
+ parser.add_argument(
+ "--enable-speech-input",
+ type=str2bool,
+ default=True,
+ help="Whether to enable speech fbank input.",
+ )
+
+ parser.add_argument(
+ "--speech-tokenizer-type",
+ type=str,
+ default="cosyvoice2",
+ help="The type of the speech tokenizer. cosyvoice2: 6561, cosyvoice1: 4096",
+ )
+
+
+def add_training_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=10,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="whisper_qwen/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--pretrained-model-path",
+ type=str,
+ default=None,
+ help="""The path to the pretrained model if it is not None. Training will
+ start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--last-stage-model-path",
+ type=str,
+ default=None,
+ help="""The path to the last stage model if it is not None. Training will start from this model.
+ """,
+ )
+ parser.add_argument(
+ "--sampler-state-dict-path",
+ type=str,
+ default=None,
+ help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict.
+ """,
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--unfreeze-llm",
+ type=str2bool,
+ default=False,
+ help="Whether to unfreeze llm during training.",
+ )
+
+ parser.add_argument(
+ "--unfreeze-speech-projector",
+ type=str2bool,
+ default=False,
+ help="Whether to unfreeze speech adaptor during training.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="multi_en",
+ help="The name of the dataset.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--loss-type",
+ type=str,
+ default="ce",
+ help="The type of loss to use.",
+ )
+
+ parser = deepspeed.add_config_arguments(parser)
+ add_model_arguments(parser)
+ add_training_arguments(parser)
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - frame_shift_ms: The frame shift in milliseconds.
+ - allowed_excess_duration_ratio: The allowed excess duration ratio.
+ - best_train_loss: The best training loss so far.
+ - best_valid_loss: The best validation loss so far.
+ - best_train_epoch: The epoch where the best training loss is achieved.
+ - best_valid_epoch: The epoch where the best validation loss is achieved.
+ - batch_idx_train: The batch index of the current batch.
+ - log_interval: Log training stats every `log_interval` batches.
+ - reset_interval: Reset the stats every `reset_interval` batches.
+ - valid_interval: Run validation every `valid_interval` batches.
+ - env_info: The environment information.
+ """
+ params = AttributeDict(
+ {
+ "allowed_excess_duration_ratio": 0.1,
+ "subsampling_factor": 2,
+ "frame_shift_ms": 10,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 1000,
+ }
+ )
+
+ return params
+
+
+def extract_text_and_speech_token(
+ batch: dict, enable_speech_output: bool
+) -> Tuple[List[Dict[str, str]], Optional[List[Any]]]:
+ """
+ Extracts messages and speech tokens from a batch based on the dataset format.
+ Uses the global DEFAULT_SPEECH_TOKEN.
+ """
+ messages = []
+ speech_tokens = None # Initialize as None
+ if enable_speech_output:
+ if "answer_cosyvoice_speech_token" in batch["supervisions"]["cut"][0].custom:
+ assert "speech_token" not in batch["supervisions"]["cut"][0].custom
+ speech_tokens = [
+ cut.custom["answer_cosyvoice_speech_token"]
+ for cut in batch["supervisions"]["cut"]
+ ]
+ elif "speech_token" in batch["supervisions"]["cut"][0].custom:
+ speech_tokens = [
+ cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
+ ]
+ else:
+ raise ValueError("Unknown speech token type")
+ answers = batch["supervisions"]["text"]
+ batch_size = len(answers)
+
+ prompt_template_dict = {
+ "speech_qa": f"{DEFAULT_SPEECH_TOKEN}",
+ "speech_continuation": f"Continue the following text using less than 50 words:\\n\\n{DEFAULT_SPEECH_TOKEN}",
+ "asr": f"Transcribe the following audio into text:\\n\\n{DEFAULT_SPEECH_TOKEN}",
+ }
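+ # The template for each item is selected below from the cut's custom fields: multi-round
+ # slam_omni data ("round") keeps speech_qa, librispeech continuation data ("continuation")
+ # mixes ASR and continuation prompts, and plain single-round items default to speech_qa.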
+
+ for i in range(batch_size):
+ # Default template for this item; may be overridden below.
+ current_prompt_template = "speech_qa"
+ target = answers[i]
+ message_list_item = []
+
+ custom_data = batch["supervisions"]["cut"][i].custom
+
+ if "round" in custom_data:
+ # slam_omni format dataset
+ # For 'round' type, the current interaction's user prompt will use current_prompt_template ("speech_qa")
+ current_question_with_history = custom_data["question"]
+ total_round = custom_data["round"]
+ history_context = current_question_with_history.rsplit(":", 1)[
+ 0
+ ].strip()
+ if total_round > 1:
+ history_question_answer = history_context.split("USER:")
+ history_question_answer = [
+ item for item in history_question_answer if item
+ ]
+ for j in range(total_round - 1):
+ question_answer = history_question_answer[j].split("ASSISTANT:")
+ message_list_item += [
+ {"role": "user", "content": question_answer[0].strip()},
+ {"role": "assistant", "content": question_answer[1].strip()},
+ ]
+ elif "continuation" in custom_data:
+ # see https://huggingface.co/datasets/fixie-ai/librispeech_asr
+ ASR_PROBABILITY = 0.3
+ if random.random() < ASR_PROBABILITY:
+ current_prompt_template = "asr"
+ else:
+ current_prompt_template = "speech_continuation"
+ target = custom_data["continuation"]
+ else:
+ # single-round, speech2speech conversation data
+ pass
+ message_list_item += [
+ {"role": "user", "content": prompt_template_dict[current_prompt_template]},
+ {"role": "assistant", "content": target},
+ ]
+ messages.append(message_list_item)
+
+ return messages, speech_tokens
+
+
+def preprocess(
+ messages,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ chat_template=TEMPLATE,
+ add_generation_prompt=False,
+ padding="longest", # FIX me change padding to longest
+ truncation=False,
+ )
+ )
+ if len(texts) != len(messages):
+ logging.warning(f"Number of tokenized texts does not match number of messages: {messages}")
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ target_ids = input_ids.clone()
+ target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+ # mask the prompt tokens (everything up to and including the assistant header that follows
+ # the speech placeholder) with IGNORE_TOKEN_ID so the loss is computed only on the answer
+ mask_prompt = True
+ if mask_prompt:
+ default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+ mask_indices = torch.where(input_ids == default_speech_token_id)
+ for i in range(mask_indices[0].size(0)):
+ row = mask_indices[0][i]
+ col = mask_indices[1][i]
+ # + 6 to skip the speech placeholder and the chat-template tokens up to the assistant header
+ # (151665, 151645, 198, 151644, 77091, 198)
+ # workaround; TODO: verify this offset for Qwen3 tokenizers
+ target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+ return input_ids, attention_mask, target_ids
+
+
+def process_batch_text_continuation(batch: dict):
+ messages = []
+ transcripts = batch["supervisions"]["text"]
+ continuations = [cut.custom["continuation"] for cut in batch["supervisions"]["cut"]]
+ for i in range(len(transcripts)):
+ message = [
+ {
+ "role": "user",
+ "content": f"Continue the following text using less than 50 words:\n\n{transcripts[i]}{DEFAULT_SPEECH_TOKEN}",
+ },
+ {"role": "assistant", "content": continuations[i]},
+ ]
+ messages.append(message)
+ return messages
+
+
+def preprocess_teacher(
+ messages,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ chat_template=TEMPLATE,
+ add_generation_prompt=False,
+ padding="longest", # FIX me change padding to longest
+ truncation=False,
+ )
+ )
+ if len(texts) != len(messages):
+ logging.warning(f"Number of tokenized texts does not match number of messages: {messages}")
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ target_ids = input_ids.clone()
+ target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+ # mask the prompt tokens with IGNORE_TOKEN_ID so the loss is computed only on the answer
+ # first get the indices of the speech placeholder tokens
+ mask_prompt = True
+ if mask_prompt:
+ default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+ mask_indices = torch.where(input_ids == default_speech_token_id)
+ for i in range(mask_indices[0].size(0)):
+ row = mask_indices[0][i]
+ col = mask_indices[1][i]
+ # + 6 to skip the speech placeholder and the chat-template tokens up to the assistant header
+ # workaround; TODO: verify this offset for Qwen3 tokenizers
+ # Unlike `preprocess`, the speech placeholder is written back into `target_ids` so that it can be
+ # stripped consistently from both `input_ids` and `target_ids` below (the teacher is text-only).
+ target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+ target_ids[row, col] = default_speech_token_id
+ # remove default_speech_token_id from target_ids and input_ids
+ batch_size = target_ids.size(0)
+
+ target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
+ input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+ return input_ids, attention_mask, target_ids
+
+
+def compute_loss(
+ params: AttributeDict,
+ tokenizer: AutoTokenizer,
+ model: nn.Module,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the loss for the given batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ tokenizer:
+ The tokenizer used to encode the text.
+ model:
+ The model for training.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ Whether it is training.
+ Returns:
+ A tuple of two elements: the loss tensor and a `MetricsTracker` with statistics for logging.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ feature = feature.transpose(1, 2) # (N, C, T)
+
+ messages, answer_cosyvoice_speech_token = extract_text_and_speech_token(
+ batch, params.enable_speech_output
+ )
+
+ input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+
+ target_ids = target_ids.type(torch.LongTensor)
+ input_ids = input_ids.type(torch.LongTensor)
+
+ with torch.set_grad_enabled(is_training):
+ if not params.enable_speech_output:
+ if params.loss_type == "ce":
+ loss, acc = model(
+ fbank=feature,
+ input_ids=input_ids.to(device),
+ attention_mask=attention_mask.to(device),
+ labels=target_ids.to(device),
+ )
+ elif params.loss_type == "kl_div":
+ messages_text = process_batch_text_continuation(batch)
+ (
+ teacher_input_ids,
+ teacher_attention_mask,
+ teacher_target_ids,
+ ) = preprocess_teacher(messages_text, tokenizer)
+ loss, acc, acc_teacher = model.forward_kl_div(
+ fbank=feature,
+ input_ids=input_ids.to(device),
+ attention_mask=attention_mask.to(device),
+ labels=target_ids.to(device),
+ teacher_input_ids=teacher_input_ids.to(device),
+ teacher_attention_mask=teacher_attention_mask.to(device),
+ teacher_labels=teacher_target_ids.to(device),
+ )
+ else:
+ raise ValueError(f"Unknown loss type: {params.loss_type}")
+ else:
+ assert params.loss_type == "ce"
+ (
+ text_loss,
+ acc,
+ codec_loss,
+ codec_acc,
+ codec_topk_acc,
+ ) = model.forward_with_speech_output(
+ fbank=feature,
+ input_ids=input_ids.to(device),
+ attention_mask=attention_mask.to(device),
+ labels=target_ids.to(device),
+ speech_codec_ids=answer_cosyvoice_speech_token,
+ )
+ loss = text_loss + codec_loss
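+ # The thinker (text) and talker (codec) losses are summed with equal weight.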
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ feature_lens = batch["supervisions"]["num_frames"]
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["acc"] = (
+ acc * info["frames"]
+ ) # workaround: pre-scale by frames so the later per-frame normalization recovers the raw accuracy
+ if params.loss_type == "kl_div":
+ info["acc_teacher"] = acc_teacher * info["frames"]
+ if params.enable_speech_output:
+ info["codec_acc"] = codec_acc * info["frames"]
+ info["codec_topk_acc"] = codec_topk_acc * info["frames"]
+ info["codec_loss"] = codec_loss.detach().cpu().item()
+ info["text_loss"] = text_loss.detach().cpu().item()
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ tokenizer: AutoTokenizer,
+ model: nn.Module,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ with torch.amp.autocast("cuda", enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ tokenizer: AutoTokenizer,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ if params.enable_speech_input:
+ model.encoder.eval()
+ if not params.unfreeze_llm:
+ model.llm.eval()
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+ if batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ if params.enable_speech_input:
+ model.encoder.eval()
+ if not params.unfreeze_llm:
+ model.llm.eval()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ if batch_idx != 0:
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"zero-checkpoint-{params.batch_idx_train}",
+ client_state={},
+ exclude_frozen_parameters=True,
+ )
+
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
+ tag=f"zero-checkpoint-{params.batch_idx_train}",
+ exclude_frozen_parameters=True,
+ )
+ # save sampler state dict into checkpoint
+ sampler_state_dict = train_dl.sampler.state_dict()
+ torch.save(
+ sampler_state_dict,
+ f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
+ )
+ os.system(
+ f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
+ )
+ try:
+ with torch.amp.autocast("cuda", enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+
+ # deepspeed's backward()/step() replace torch's loss.backward() and optimizer.step();
+ # the engine handles gradient scaling, accumulation and ZeRO partitioning internally.
+ model.backward(loss)
+ model.step()
+
+ except: # noqa
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if batch_idx % params.log_interval == 0:
+ try:
+ cur_lr = scheduler.get_last_lr()[0]
+ except: # noqa
+ cur_lr = 0.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def get_model(params):
+ """Load and prepare the speech-to-speech model."""
+ tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+ tokenizer.add_special_tokens(special_tokens_dict)
+
+ if params.use_flash_attn:
+ attn_implementation = "flash_attention_2"
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "left"
+
+ else:
+ attn_implementation = "eager"
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "right"
+
+ llm = AutoModelForCausalLM.from_pretrained(
+ params.llm_path_or_name,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch_dtype,
+ )
+ if not params.unfreeze_llm:
+ for name, param in llm.named_parameters():
+ param.requires_grad = False
+ if params.use_lora:
+ lora_config = LoraConfig(
+ r=64,
+ lora_alpha=16,
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ],
+ lora_dropout=0.05,
+ task_type="CAUSAL_LM",
+ )
+ llm = get_peft_model(llm, lora_config)
+ llm.print_trainable_parameters()
+
+ llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+ llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+ llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+ DEFAULT_SPEECH_TOKEN
+ )
+
+ if params.enable_speech_input:
+ if params.remove_whisper_encoder_input_length_restriction:
+ replace_whisper_encoder_forward()
+ whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+ speech_encoder = whisper_model.encoder
+ speech_encoder_dim = whisper_model.dims.n_audio_state
+ for name, param in speech_encoder.named_parameters():
+ param.requires_grad = False
+ encoder_projector = EncoderProjector(
+ speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+ )
+ if not params.unfreeze_speech_projector:
+ for name, param in encoder_projector.named_parameters():
+ param.requires_grad = False
+ encoder_projector.eval()
+ else:
+ speech_encoder = None
+ encoder_projector = None
+
+ if params.enable_speech_output:
+ # Determine attn_implementation and torch_dtype based on use_flash_attn
+ if params.use_flash_attn:
+ attn_implementation = "flash_attention_2"
+ torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
+ else:
+ attn_implementation = "eager"
+ torch_dtype = torch.float16
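+ # The +4 below reserves ids for the codec LM's special tokens (pad/eos/bos/mask, assigned further down).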
+ if params.speech_tokenizer_type == "cosyvoice2":
+ codec_vocab_size = 6561 + 4
+ elif params.speech_tokenizer_type == "cosyvoice1":
+ codec_vocab_size = 4096 + 4
+ else:
+ raise ValueError(
+ f"Unknown speech tokenizer type: {params.speech_tokenizer_type}"
+ )
+
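+ # A small Qwen2-style decoder, built from scratch, serves as the codec (speech-token) LM.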
+ config = Qwen2Config(
+ vocab_size=codec_vocab_size,
+ hidden_size=1024,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ intermediate_size=2048,
+ max_position_embeddings=4096,
+ )
+
+ codec_lm = AutoModelForCausalLM.from_config(
+ config=config,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch_dtype,
+ )
+
+ codec_lm.resize_token_embeddings(codec_vocab_size)
+ codec_lm.vocab_size = codec_vocab_size
+ codec_lm.config.pad_token_id = codec_vocab_size - 1
+ codec_lm.config.eos_token_id = codec_vocab_size - 2
+ codec_lm.config.bos_token_id = codec_vocab_size - 3
+ codec_lm.config.mask_token_id = codec_vocab_size - 4
+ else:
+ codec_lm = None
+
+ model = SPEECH_LLM(
+ speech_encoder,
+ llm,
+ encoder_projector,
+ codec_lm,
+ codec_lm_padding_side="left" if params.use_flash_attn else "right",
+ )
+ if params.pretrained_model_path or params.last_stage_model_path:
+ if params.pretrained_model_path is None:
+ checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
+ missing_keys, unexpected_keys = model.load_state_dict(
+ checkpoint, strict=False
+ )
+ else:
+ checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+ missing_keys, unexpected_keys = model.load_state_dict(
+ checkpoint, strict=False
+ )
+ # set params.batch_idx_train according to the checkpoint name
+ if "checkpoint-" in params.pretrained_model_path:
+ params.batch_idx_train = int(
+ params.pretrained_model_path.split("-")[-1].split("/")[0]
+ )
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ logging.info("Trainable parameters (excluding model.eval modules):")
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ logging.info(f"{name}: {param.shape}")
+
+ return model, tokenizer
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+
+ if rank == 0:
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info(params)
+ logging.info("About to create model")
+
+ model, tokenizer = get_model(params)
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda", get_local_rank())
+ else:
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+ model.to(device)
+
+ assert params.deepspeed and world_size > 1
+ logging.info("Using DeepSpeed")
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=params, model=model, model_parameters=model.parameters()
+ )
+
+ data_module = AsrDataModule(args)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 0.8 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
+ codec_len = (
+ len(c.custom["answer_cosyvoice_speech_token"])
+ if "answer_cosyvoice_speech_token" in c.custom
+ else len(c.custom["speech_token"])
+ )
+ if codec_len > 2200:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
+ )
+ return False
+ if "question" in c.custom:
+ if len(c.custom["question"]) > 1200:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. question length: {len(c.custom['question'])}"
+ # )
+ return False
+ return True
+
+ if params.dataset == "slam_omni_belle":
+ train_cuts = data_module.train_cuts_belle()
+ valid_cuts = data_module.dev_cuts_belle()
+ elif params.dataset == "vocalnet_ultrachat_voiceassistant":
+ train_cuts = data_module.train_cuts_en_vocalnet()
+ valid_cuts = data_module.valid_cuts_en_vocalnet()
+ elif params.dataset == "vocalnet_ultrachat_voiceassistant_instruct_s2s":
+ train_cuts = data_module.train_cuts_en_speech2speech()
+ valid_cuts = data_module.valid_cuts_en_vocalnet()
+ elif params.dataset == "vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech":
+ train_cuts = data_module.train_cuts_en_speech2speech_librispeech()
+ valid_cuts = data_module.valid_cuts_en_vocalnet()
+ elif params.dataset == "ultravox_multi_en":
+ train_cuts = data_module.train_cuts_ultravox()
+ valid_cuts = data_module.valid_cuts_ultravox()
+ elif params.dataset == "librispeech":
+ train_cuts = data_module.train_cuts_librispeech()
+ valid_cuts = data_module.valid_cuts_ultravox()
+ elif params.dataset == "gigaspeech":
+ train_cuts = data_module.train_cuts_gigaspeech()
+ valid_cuts = data_module.valid_cuts_ultravox()
+ else:
+ raise ValueError(f"Unknown dataset: {params.dataset}")
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+
+ sampler_state_dict = None
+ if params.sampler_state_dict_path:
+ sampler_state_dict = torch.load(params.sampler_state_dict_path)
+ sampler_state_dict["max_duration"] = params.max_duration
+
+ train_dl = data_module.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+ valid_dl = data_module.valid_dataloaders(valid_cuts)
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ logging.info(f"start training from epoch {params.start_epoch}")
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"zero-epoch-{params.cur_epoch}",
+ client_state={},
+ exclude_frozen_parameters=True,
+ )
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}",
+ tag=f"zero-epoch-{params.cur_epoch}",
+ exclude_frozen_parameters=True,
+ )
+ # save sampler state dict into checkpoint
+ sampler_state_dict = train_dl.sampler.state_dict()
+ torch.save(
+ sampler_state_dict,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
+ )
+
+ os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
+
+ logging.info("Done!")
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
new file mode 100755
index 000000000..e505c0700
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
@@ -0,0 +1,604 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+ --max-duration 50 \
+ --enable-musan False \
+ --exp-dir $exp_dir \
+ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+ --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+"""
+
+import argparse
+import copy
+import logging
+import os
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import deepspeed
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import transformers
+from datasets import load_dataset
+
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
+
+from lhotse.utils import fix_random_seed
+from model import IGNORE_TOKEN_ID, SPEECH_LLM
+from peft import LoraConfig, get_peft_model
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ Qwen2Config,
+ Qwen2ForCausalLM,
+)
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torch.utils.data import DistributedSampler, DataLoader
+from pathlib import Path
+
+from train import add_model_arguments, add_training_arguments, get_params, get_model
+from utils import ( # filter_uneven_sized_batch,
+ AttributeDict,
+ MetricsTracker,
+ get_local_rank,
+ get_rank,
+ get_world_size,
+ setup_logger,
+ str2bool,
+)
+
+DEFAULT_SPEECH_TOKEN = ""
+try:
+ torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+ pass
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=16,
+ help="The batch size to use.",
+ )
+
+ parser = deepspeed.add_config_arguments(parser)
+ add_model_arguments(parser)
+ add_training_arguments(parser)
+ return parser
+
+def preprocess(
+ messages,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
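+ # Qwen-style chat template: each message becomes "<|im_start|>role\ncontent<|im_end|>".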
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ chat_template=TEMPLATE,
+ add_generation_prompt=False,
+ padding="longest", # FIX me change padding to longest
+ truncation=False,
+ )
+ )
+ if len(texts) != len(messages):
+ logging.warning(f"Remove too long text, {messages} ")
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ target_ids = input_ids.clone()
+ target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+ # mask all tokens before token_id with IGNORE_TOKEN_ID
+ # first get the indices of the tokens
+ mask_prompt = True
+ if mask_prompt:
+ default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+ mask_indices = torch.where(input_ids == default_speech_token_id)
+ for i in range(mask_indices[0].size(0)):
+ row = mask_indices[0][i]
+ col = mask_indices[1][i]
+ # + 6 skips the role/format tokens before the assistant answer (rather than
+ # + 2 for 'assistant', '\n'); re-check this offset for other chat templates (e.g. Qwen3).
+ # This line is the only difference from preprocess() in train.py.
+ target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+ target_ids[row, col] = default_speech_token_id
+ # remove default_speech_token_id from target_ids and input_ids
+ batch_size = target_ids.size(0)
+
+ target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
+ input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+ return input_ids, attention_mask, target_ids
+
+def data_collator(batch):
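+ """Collate raw TTS items: keep the codec token sequence and build a text-to-speech chat prompt for each sample."""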
+ speech_tokens, messages, durations, ids, lang = [], [], [], [], []
+ for i, item in enumerate(batch):
+ speech_tokens.append(item["code"])
+ message_list_item = []
+ message_list_item += [
+ {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
+ {"role": "assistant", "content": item["text"]},
+ ]
+ # message_list_item += [
+ # {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
+ # {"role": "assistant", "content": item["text"]},
+ # ]
+ messages.append(message_list_item)
+ durations.append(item["duration"])
+ ids.append(item["index"] if "index" in item else item["id"])
+ lang.append(item["language"])
+
+ return {
+ "speech_tokens": speech_tokens,
+ "messages": messages,
+ "durations": durations,
+ "ids": ids,
+ "lang": lang,
+ }
+
+def data_collator_ultra_chat(batch):
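+ """Collate lhotse-style cut dicts (UltraChat / VoiceAssistant): codec tokens live in item["custom"]["speech_token"]."""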
+ speech_tokens, messages, durations, ids = [], [], [], []
+ for i, item in enumerate(batch):
+ speech_tokens.append(item["custom"]["speech_token"])
+ text = item["supervisions"][0]["text"]
+ message_list_item = []
+ message_list_item += [
+ {"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
+ {"role": "assistant", "content": text},
+ ]
+ messages.append(message_list_item)
+ durations.append(item["duration"])
+ ids.append(item["id"])
+
+ return {
+ "speech_tokens": speech_tokens,
+ "messages": messages,
+ "durations": durations,
+ "ids": ids,
+ }
+
+def compute_loss(
+ params: AttributeDict,
+ tokenizer: AutoTokenizer,
+ model: nn.Module,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the loss for the given batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ tokenizer:
+ The tokenizer used to encode the text.
+ model:
+ The model for training.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ Whether it is training.
+ Returns:
+ Return a tuple of two elements. The first element is the loss tensor,
+ the second is a MetricsTracker with statistics for logging.
+ """
+ device = next(model.parameters()).device
+ messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
+ input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+ target_ids = target_ids.type(torch.LongTensor)
+ input_ids = input_ids.type(torch.LongTensor)
+
+ with torch.set_grad_enabled(is_training):
+ (
+ text_loss,
+ acc,
+ codec_loss,
+ codec_acc,
+ codec_topk_acc,
+ ) = model.forward_with_speech_output(
+ input_ids=input_ids.to(device),
+ attention_mask=attention_mask.to(device),
+ labels=target_ids.to(device),
+ speech_codec_ids=answer_cosyvoice_speech_token,
+ )
+ loss = text_loss + codec_loss
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ info["frames"] = len(messages)
+ # Note: We use reduction=sum while computing the loss.
+ info["acc"] = acc * len(messages)
+ info["codec_acc"] = codec_acc * len(messages)
+ info["codec_topk_acc"] = codec_topk_acc * len(messages)
+ info["loss"] = loss.detach().cpu().item()
+ info["codec_loss"] = codec_loss.detach().cpu().item()
+ info["text_loss"] = text_loss.detach().cpu().item()
+ return loss, info
+
+def compute_validation_loss(
+ params: AttributeDict,
+ tokenizer: AutoTokenizer,
+ model: nn.Module,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ with torch.amp.autocast("cuda", enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ # FIX ME
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+def train_one_epoch(
+ params: AttributeDict,
+ tokenizer: AutoTokenizer,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ tokenizer:
+ The tokenizer used to encode the text.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ # model.encoder.eval()
+ if not params.unfreeze_llm:
+ model.llm.eval()
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["durations"])
+ if batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ # model.encoder.eval()
+ if not params.unfreeze_llm:
+ model.llm.eval()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ if batch_idx != 0:
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"zero-checkpoint-{params.batch_idx_train}",
+ client_state={},
+ exclude_frozen_parameters=True,
+ )
+
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
+ tag=f"zero-checkpoint-{params.batch_idx_train}",
+ exclude_frozen_parameters=True,
+ )
+ # save sampler state dict into checkpoint
+ # sampler_state_dict = train_dl.sampler.state_dict()
+ sampler_state_dict = train_dl.state_dict()
+ torch.save(
+ sampler_state_dict,
+ f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
+ )
+ os.system(
+ f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
+ )
+ try:
+ with torch.amp.autocast("cuda", enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+
+ # deepspeed's backward() is different from torch's backward()
+ # in that it does not accept a loss tensor as input.
+ # It computes the loss internally.
+ model.backward(loss)
+ model.step()
+
+ except: # noqa
+ raise
+
+ if batch_idx % params.log_interval == 0:
+ try:
+ cur_lr = scheduler.get_last_lr()[0]
+ except: # noqa
+ cur_lr = 0.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = 2000
+
+ fix_random_seed(params.seed)
+
+ if rank == 0:
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info(params)
+ logging.info("About to create model")
+ model, tokenizer = get_model(params)
+ if torch.cuda.is_available():
+ device = torch.device("cuda", get_local_rank())
+ else:
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+ model.to(device)
+
+ # assert params.deepspeed and world_size > 1
+ logging.info("Using DeepSpeed")
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=params, model=model, model_parameters=model.parameters()
+ )
+
+ sampler_state_dict = None
+ if params.sampler_state_dict_path:
+ sampler_state_dict = torch.load(params.sampler_state_dict_path)
+ if params.dataset == "ultra_chat_voice_assistant":
+ data_dir = "data/fbank"
+ json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
+ ds = load_dataset("json", data_files=json_file_lists, split="train")
+ # shuffle the dataset
+ train_dataset = ds.shuffle(seed=42)
+ eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
+ else:
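+ # Treat params.dataset as a directory of *.jsonl manifests and hold out 1000 random items for validation.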
+ data_dir = Path(params.dataset)
+ json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
+ ds = load_dataset("json", data_files=json_file_lists, split="train")
+ # shuffle the dataset
+ ds = ds.shuffle(seed=42)
+ train_test_split = ds.train_test_split(test_size=1000, seed=42)
+ train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+
+ sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
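+ # StatefulDataLoader keeps a resumable position, saved/restored via state_dict()/load_state_dict().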
+ train_dl = StatefulDataLoader(
+ train_dataset,
+ batch_size=params.batch_size,
+ sampler=sampler,
+ shuffle=False,
+ num_workers=4,
+ prefetch_factor=2,
+ collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
+ )
+ if sampler_state_dict is not None:
+ train_dl.load_state_dict(sampler_state_dict)
+ valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
+ valid_dl = DataLoader(
+ eval_dataset,
+ batch_size=params.batch_size,
+ sampler=valid_sampler,
+ shuffle=False,
+ num_workers=1,
+ prefetch_factor=1,
+ collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
+ )
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ logging.info(f"start training from epoch {params.start_epoch}")
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"zero-epoch-{params.cur_epoch}",
+ client_state={},
+ exclude_frozen_parameters=True,
+ )
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}",
+ tag=f"zero-epoch-{params.cur_epoch}",
+ exclude_frozen_parameters=True,
+ )
+ # save sampler state dict into checkpoint
+ # sampler_state_dict = train_dl.sampler.state_dict()
+ sampler_state_dict = train_dl.state_dict()
+ torch.save(
+ sampler_state_dict,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
+ )
+
+ os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
+
+ logging.info("Done!")
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
new file mode 100644
index 000000000..fad7f272c
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
@@ -0,0 +1,433 @@
+import argparse
+import collections
+import json
+import logging
+import os
+import pathlib
+import random
+import re
+import subprocess
+from collections import defaultdict
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
+from tqdm import tqdm
+import kaldialign
+import torch
+import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+import numpy as np
+Pathlike = Union[str, Path]
+
+
+def get_world_size():
+ if "WORLD_SIZE" in os.environ:
+ return int(os.environ["WORLD_SIZE"])
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size()
+ else:
+ return 1
+
+
+def get_rank():
+ if "RANK" in os.environ:
+ return int(os.environ["RANK"])
+ elif dist.is_available() and dist.is_initialized():
+ return dist.get_rank()
+ else:
+ return 0
+
+
+def get_local_rank():
+ if "LOCAL_RANK" in os.environ:
+ return int(os.environ["LOCAL_RANK"])
+ elif dist.is_available() and dist.is_initialized():
+ return dist.get_local_rank()
+ else:
+ return 0
+
+
+def str2bool(v):
+ """Used in argparse.ArgumentParser.add_argument to indicate
+ that a type is a bool type and user can enter
+
+ - yes, true, t, y, 1, to represent True
+ - no, false, f, n, 0, to represent False
+
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+class AttributeDict(dict):
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ if key in self:
+ del self[key]
+ return
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __str__(self, indent: int = 2):
+ tmp = {}
+ for k, v in self.items():
+ # PosixPath is not JSON serializable
+ if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
+ v = str(v)
+ tmp[k] = v
+ return json.dumps(tmp, indent=indent, sort_keys=True)
+
+
+def setup_logger(
+ log_filename: Pathlike,
+ log_level: str = "info",
+ use_console: bool = True,
+) -> None:
+ """Setup log level.
+
+ Args:
+ log_filename:
+ The filename to save the log.
+ log_level:
+ The log level to use, e.g., "debug", "info", "warning", "error",
+ "critical"
+ use_console:
+ True to also print logs to console.
+ """
+ now = datetime.now()
+ date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
+ if dist.is_available() and dist.is_initialized():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
+ log_filename = f"{log_filename}-{date_time}-{rank}"
+ else:
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ log_filename = f"{log_filename}-{date_time}"
+
+ os.makedirs(os.path.dirname(log_filename), exist_ok=True)
+
+ level = logging.ERROR
+ if log_level == "debug":
+ level = logging.DEBUG
+ elif log_level == "info":
+ level = logging.INFO
+ elif log_level == "warning":
+ level = logging.WARNING
+ elif log_level == "critical":
+ level = logging.CRITICAL
+
+ logging.basicConfig(
+ filename=log_filename,
+ format=formatter,
+ level=level,
+ filemode="w",
+ force=True,
+ )
+ if use_console:
+ console = logging.StreamHandler()
+ console.setLevel(level)
+ console.setFormatter(logging.Formatter(formatter))
+ logging.getLogger("").addHandler(console)
+
+
+class MetricsTracker(collections.defaultdict):
+ def __init__(self):
+ # Passing the type 'int' to the base-class constructor
+ # makes undefined items default to int() which is zero.
+ # This class will play a role as metrics tracker.
+ # It can record many metrics, including but not limited to loss.
+ super(MetricsTracker, self).__init__(int)
+
+ def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v
+ for k, v in other.items():
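+ # "v - v == 0" is False for inf/NaN, so non-finite metrics are skipped when accumulating.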
+ if v - v == 0:
+ ans[k] = ans[k] + v
+ return ans
+
+ def __mul__(self, alpha: float) -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v * alpha
+ return ans
+
+ def __str__(self) -> str:
+ ans_frames = ""
+ ans_utterances = ""
+ for k, v in self.norm_items():
+ norm_value = "%.4g" % v
+ if "utt_" not in k:
+ ans_frames += str(k) + "=" + str(norm_value) + ", "
+ else:
+ ans_utterances += str(k) + "=" + str(norm_value)
+ if k == "utt_duration":
+ ans_utterances += " frames, "
+ elif k == "utt_pad_proportion":
+ ans_utterances += ", "
+ else:
+ raise ValueError(f"Unexpected key: {k}")
+ frames = "%.2f" % self["frames"]
+ ans_frames += "over " + str(frames) + " frames. "
+ if ans_utterances != "":
+ utterances = "%.2f" % self["utterances"]
+ ans_utterances += "over " + str(utterances) + " utterances."
+
+ return ans_frames + ans_utterances
+
+ def norm_items(self) -> List[Tuple[str, float]]:
+ """
+ Returns a list of pairs, like:
+ [('ctc_loss', 0.1), ('att_loss', 0.07)]
+ """
+ num_frames = self["frames"] if "frames" in self else 1
+ num_utterances = self["utterances"] if "utterances" in self else 1
+ ans = []
+ for k, v in self.items():
+ if k == "frames" or k == "utterances":
+ continue
+ norm_value = (
+ float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
+ )
+ ans.append((k, norm_value))
+ return ans
+
+ def reduce(self, device):
+ """
+ Reduce using torch.distributed, which I believe ensures that
+ all processes get the total.
+ """
+ keys = sorted(self.keys())
+ s = torch.tensor([float(self[k]) for k in keys], device=device)
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
+ for k, v in zip(keys, s.cpu().tolist()):
+ self[k] = v
+
+ def write_summary(
+ self,
+ tb_writer: SummaryWriter,
+ prefix: str,
+ batch_idx: int,
+ ) -> None:
+ """Add logging information to a TensorBoard writer.
+
+ Args:
+ tb_writer: a TensorBoard writer
+ prefix: a prefix for the name of the loss, e.g. "train/valid_",
+ or "train/current_"
+ batch_idx: The current batch index, used as the x-axis of the plot.
+ """
+ for k, v in self.norm_items():
+ tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+def store_transcripts(
+ filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
+) -> None:
+ """Save predicted results and reference transcripts to a file.
+
+ Args:
+ filename:
+ File to save the results to.
+ texts:
+ An iterable of tuples. The first element is the cur_id, the second is
+ the reference transcript and the third element is the predicted result.
+ If it is a multi-talker ASR system, the ref and hyp may also be lists of
+ strings.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf8") as f:
+ for cut_id, ref, hyp in texts:
+ if char_level:
+ ref = list("".join(ref))
+ hyp = list("".join(hyp))
+ print(f"{cut_id}:\tref={ref}", file=f)
+ print(f"{cut_id}:\thyp={hyp}", file=f)
+
+
+def write_error_stats(
+ f: TextIO,
+ test_set_name: str,
+ results: List[Tuple[str, str]],
+ enable_log: bool = True,
+ compute_CER: bool = False,
+ sclite_mode: bool = False,
+) -> float:
+ """Write statistics based on predicted results and reference transcripts.
+
+ It will write the following to the given file:
+
+ - WER
+ - number of insertions, deletions, substitutions, corrects and total
+ reference words. For example::
+
+ Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+ reference words (2337 correct)
+
+ - The difference between the reference transcript and predicted result.
+ An instance is given below::
+
+ THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+ The above example shows that the reference word is `EDISON`,
+ but it is predicted to `ADDISON` (a substitution error).
+
+ Another example is::
+
+ FOR THE FIRST DAY (SIR->*) I THINK
+
+ The reference word `SIR` is missing in the predicted
+ results (a deletion error).
+ results:
+ An iterable of tuples. The first element is the cut_id, the second is
+ the reference transcript and the third element is the predicted result.
+ enable_log:
+ If True, also print detailed WER to the console.
+ Otherwise, it is written only to the given file.
+ Returns:
+ Return the total error rate (WER, or CER if compute_CER is True) as a percentage (float).
+ """
+ subs: Dict[Tuple[str, str], int] = defaultdict(int)
+ ins: Dict[str, int] = defaultdict(int)
+ dels: Dict[str, int] = defaultdict(int)
+
+ # `words` stores counts per word, as follows:
+ # corr, ref_sub, hyp_sub, ins, dels
+ words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+ num_corr = 0
+ ERR = "*"
+
+ if compute_CER:
+ for i, res in enumerate(results):
+ cut_id, ref, hyp = res
+ ref = list("".join(ref))
+ hyp = list("".join(hyp))
+ results[i] = (cut_id, ref, hyp)
+
+ for cut_id, ref, hyp in results:
+ ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
+ for ref_word, hyp_word in ali:
+ if ref_word == ERR:
+ ins[hyp_word] += 1
+ words[hyp_word][3] += 1
+ elif hyp_word == ERR:
+ dels[ref_word] += 1
+ words[ref_word][4] += 1
+ elif hyp_word != ref_word:
+ subs[(ref_word, hyp_word)] += 1
+ words[ref_word][1] += 1
+ words[hyp_word][2] += 1
+ else:
+ words[ref_word][0] += 1
+ num_corr += 1
+ ref_len = sum([len(r) for _, r, _ in results])
+ sub_errs = sum(subs.values())
+ ins_errs = sum(ins.values())
+ del_errs = sum(dels.values())
+ tot_errs = sub_errs + ins_errs + del_errs
+ tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+ if enable_log:
+ logging.info(
+ f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+ f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+ f"{del_errs} del, {sub_errs} sub ]"
+ )
+
+ print(f"%WER = {tot_err_rate}", file=f)
+ print(
+ f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+ f"{sub_errs} substitutions, over {ref_len} reference "
+ f"words ({num_corr} correct)",
+ file=f,
+ )
+ print(
+ "Search below for sections starting with PER-UTT DETAILS:, "
+ "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+ file=f,
+ )
+
+ print("", file=f)
+ print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+ for cut_id, ref, hyp in results:
+ ali = kaldialign.align(ref, hyp, ERR)
+ combine_successive_errors = True
+ if combine_successive_errors:
+ ali = [[[x], [y]] for x, y in ali]
+ for i in range(len(ali) - 1):
+ if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+ ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+ ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+ ali[i] = [[], []]
+ ali = [
+ [
+ list(filter(lambda a: a != ERR, x)),
+ list(filter(lambda a: a != ERR, y)),
+ ]
+ for x, y in ali
+ ]
+ ali = list(filter(lambda x: x != [[], []], ali))
+ ali = [
+ [
+ ERR if x == [] else " ".join(x),
+ ERR if y == [] else " ".join(y),
+ ]
+ for x, y in ali
+ ]
+
+ print(
+ f"{cut_id}:\t"
+ + " ".join(
+ (
+ ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+ for ref_word, hyp_word in ali
+ )
+ ),
+ file=f,
+ )
+
+ print("", file=f)
+ print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+ for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+ print(f"{count} {ref} -> {hyp}", file=f)
+
+ print("", file=f)
+ print("DELETIONS: count ref", file=f)
+ for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+ print(f"{count} {ref}", file=f)
+
+ print("", file=f)
+ print("INSERTIONS: count hyp", file=f)
+ for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+ print(f"{count} {hyp}", file=f)
+
+ print("", file=f)
+ print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+ for _, word, counts in sorted(
+ [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+ ):
+ (corr, ref_sub, hyp_sub, ins, dels) = counts
+ tot_errs = ref_sub + hyp_sub + ins + dels
+ ref_count = corr + ref_sub + dels
+ hyp_count = corr + hyp_sub + ins
+
+ print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+ return float(tot_err_rate)
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
new file mode 100644
index 000000000..562079044
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
@@ -0,0 +1,434 @@
+# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
+import io
+import sys
+from argparse import ArgumentParser
+
+import gradio as gr
+import gradio.processing_utils as processing_utils
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+import torch
+import whisper
+#from cosyvoice.cli.cosyvoice import CosyVoice
+from gradio_client import utils as client_utils
+from model import SPEECH_LLM, EncoderProjector
+from peft import LoraConfig, get_peft_model
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def get_model(params, device="cuda"):
+ """Load and prepare the speech-to-speech model."""
+ if params.remove_whisper_encoder_input_length_restriction:
+ replace_whisper_encoder_forward()
+
+ whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+ speech_encoder = whisper_model.encoder
+ speech_encoder_dim = whisper_model.dims.n_audio_state
+ tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+
+ if params.use_flash_attn:
+ attn_implementation = "flash_attention_2"
+ else:
+ attn_implementation = "eager"
+
+ llm = AutoModelForCausalLM.from_pretrained(
+ params.llm_path_or_name,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch.float16,
+ )
+ if params.use_lora:
+ lora_config = LoraConfig(
+ r=64,
+ lora_alpha=16,
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ],
+ task_type="CAUSAL_LM",
+ )
+ llm = get_peft_model(llm, lora_config)
+ llm.print_trainable_parameters()
+
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+ llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+ llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+ DEFAULT_SPEECH_TOKEN
+ )
+
+ encoder_projector = EncoderProjector(
+ speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+ )
+
+ # codec_vocab_size = 4096 + 4
+ codec_vocab_size = 6561 + 4
+ config = Qwen2Config(
+ vocab_size=codec_vocab_size,
+ hidden_size=1024,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ intermediate_size=2048,
+ max_position_embeddings=4096,
+ )
+ codec_lm = AutoModelForCausalLM.from_config(
+ config=config,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch.float16,
+ )
+ codec_lm.resize_token_embeddings(codec_vocab_size)
+ codec_lm.vocab_size = codec_vocab_size
+ codec_lm.config.pad_token_id = codec_vocab_size - 1
+ codec_lm.config.eos_token_id = codec_vocab_size - 2
+ codec_lm.config.bos_token_id = codec_vocab_size - 3
+ codec_lm.config.mask_token_id = codec_vocab_size - 4
+
+ model = SPEECH_LLM(
+ speech_encoder,
+ llm,
+ encoder_projector,
+ codec_lm,
+ codec_lm_padding_side="left" if params.use_flash_attn else "right",
+ )
+
+ checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
+ model.load_state_dict(checkpoint, strict=False)
+
+ model.to(device)
+ model.eval()
+ return model, tokenizer
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+ """
+ Generate audio from tokens with optional tone and prompt embedding.
+
+ Args:
+ audio_tokens (list): List of audio tokens to be processed.
+ codec_decoder: Codec decoder for generating audio.
+
+ Returns:
+ torch.Tensor: Generated audio waveform.
+ """
+ flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
+ flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+ prompt_speech_feat = torch.zeros(1, 0, 80)
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+ prompt_token_len=torch.tensor(
+ [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+ prompt_feat_len=torch.tensor(
+ [prompt_speech_feat.shape[1]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ embedding=flow_embedding.to(codec_decoder.model.device),
+ flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+ )
+
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
+
+ return audio_hat
+
+
+def preprocess(
+ messages,
+ tokenizer,
+):
+ """Preprocesses the data for supervised fine-tuning."""
+ texts = []
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ add_generation_prompt=False,
+ chat_template=TEMPLATE,
+ padding="longest",
+ truncation=False,
+ )
+ )
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+ return input_ids, attention_mask
+
+
+def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
+ def format_history(history: list):
+ messages = []
+ for item in history:
+ if isinstance(item["content"], str):
+ messages.append({"role": item["role"], "content": item["content"]})
+ return messages
+
+ def decode(
+ model,
+ token2wav_model,
+ tokenizer,
+ feature,
+ messages,
+ ):
+ """Decode one
+ Returns:
+ pass
+ """
+
+ dtype = torch.float32
+ device = model.llm.device
+
+ feature = feature.to(device, dtype=dtype)
+
+ input_ids, attention_mask = preprocess([messages], tokenizer)
+
+ generated_ids, audio_tokens = model.decode_with_speech_output(
+ feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+ )
+
+ hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+ yield {"type": "text", "data": hyps[0]}
+
+ audio_tokens = [token for token in audio_tokens if token < 4096]
+ audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+ audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+ audio = audio_hat.squeeze(0).cpu().numpy()
+ audio = np.array(audio * 32767).astype(np.int16)
+ wav_io = io.BytesIO()
+ sf.write(wav_io, audio, samplerate=22050, format="WAV")
+ wav_io.seek(0)
+ wav_bytes = wav_io.getvalue()
+ audio_path = processing_utils.save_bytes_to_cache(
+ wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
+ )
+
+ yield {"type": "audio", "data": audio_path}
+
+ def media_predict(audio, history):
+ # First yield
+ yield (
+ None, # microphone
+ history, # media_chatbot
+ gr.update(visible=False), # submit_btn
+ gr.update(visible=True), # stop_btn
+ )
+ print(2333, history, audio)
+ history.append({"role": "user", "content": (audio,)})
+ history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
+ history.append({"role": "assistant", "content": ""})
+ formatted_history = format_history(
+ history=history
+ ) # only keep string text format
+
+ assert audio is not None
+ audio_transcript = get_transcript(
+ audio,
+ asr_model,
+ )
+ history[-2]["content"] = audio_transcript
+
+ fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
+ fbank = fbank.unsqueeze(0)
+ assert fbank.ndim == 3
+
+ for chunk in decode(
+ model, token2wav_model, tokenizer, fbank, formatted_history
+ ):
+ if chunk["type"] == "text":
+ history[-1]["content"] = chunk["data"]
+ yield (
+ None, # microphone
+ history, # media_chatbot
+ gr.update(visible=False), # submit_btn
+ gr.update(visible=True), # stop_btn
+ )
+ if chunk["type"] == "audio":
+ history.append(
+ {"role": "assistant", "content": gr.Audio(chunk["data"])}
+ )
+
+ # Final yield
+ yield (
+ None, # microphone
+ history, # media_chatbot
+ gr.update(visible=True), # submit_btn
+ gr.update(visible=False), # stop_btn
+ )
+
+ with gr.Blocks() as demo:
+ with gr.Tab("Online"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ microphone = gr.Audio(sources=["microphone"], type="filepath")
+ submit_btn = gr.Button("Submit", variant="primary")
+ stop_btn = gr.Button("Stop", visible=False)
+ clear_btn = gr.Button("Clear History")
+ with gr.Column(scale=2):
+ media_chatbot = gr.Chatbot(height=650, type="messages")
+
+ def clear_history():
+ return [], gr.update(value=None)
+
+ submit_event = submit_btn.click(
+ fn=media_predict,
+ inputs=[
+ microphone,
+ media_chatbot,
+ ],
+ outputs=[microphone, media_chatbot, submit_btn, stop_btn],
+ )
+ stop_btn.click(
+ fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
+ inputs=None,
+ outputs=[submit_btn, stop_btn],
+ cancels=[submit_event],
+ queue=False,
+ )
+ clear_btn.click(
+ fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
+ )
+
+ demo.queue(default_concurrency_limit=100, max_size=100).launch(
+ max_threads=100,
+ ssr_mode=False,
+ share=args.share,
+ inbrowser=args.inbrowser,
+ server_port=args.server_port,
+ server_name=args.server_name,
+ )
+
+
+def _get_args():
+ parser = ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint-path",
+ type=str,
+ default=None,
+ help="Checkpoint name or path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--token2wav-path",
+ type=str,
+ default=None,
+ help="Token2Wav path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--asr-model-dir",
+ type=str,
+ default=None,
+ help="ASR model dir, default to %(default)r",
+ )
+ parser.add_argument(
+ "--flash-attn2",
+ action="store_true",
+ default=False,
+ help="Enable flash_attention_2 when loading the model.",
+ )
+ parser.add_argument(
+ "--share",
+ action="store_true",
+ default=False,
+ help="Create a publicly shareable link for the interface.",
+ )
+ parser.add_argument(
+ "--inbrowser",
+ action="store_true",
+ default=False,
+ help="Automatically launch the interface in a new tab on the default browser.",
+ )
+ parser.add_argument(
+ "--server-port", type=int, default=8001, help="Demo server port."
+ )
+ parser.add_argument(
+ "--server-name", type=str, default="127.0.0.1", help="Demo server name."
+ )
+ add_model_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+def read_wave(wave_filename: str):
+ """
+ Args:
+ wave_filename:
+ Path to a wave file. It should be single channel and can be of type
+ 32-bit floating point PCM. Its sample rate does not need to match the
+ recognizer's; it is returned together with the samples.
+
+ Returns:
+ Return a tuple containing:
+ - A 1-D array of dtype np.float32 containing the samples,
+ which are normalized to the range [-1, 1].
+ - Sample rate of the wave file.
+ """
+
+ samples, sample_rate = sf.read(wave_filename, dtype="float32")
+ assert (
+ samples.ndim == 1
+ ), f"Expected single channel, but got {samples.ndim} channels."
+
+ samples_float32 = samples.astype(np.float32)
+
+ return samples_float32, sample_rate
+
+
+def get_transcript(audio_path, recognizer):
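+ """Transcribe a wave file with the sherpa-onnx offline recognizer."""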
+ samples, sample_rate = read_wave(audio_path)
+ s = recognizer.create_stream()
+ s.accept_waveform(sample_rate, samples)
+ recognizer.decode_streams([s])
+ return s.result.text
+
+
+if __name__ == "__main__":
+ args = _get_args()
+ model, tokenizer = get_model(args)
+ token2wav = CosyVoice(
+ args.token2wav_path, load_jit=False, load_trt=False, fp16=False
+ )
+
+ asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
+ paraformer=f"{args.asr_model_dir}/model.int8.onnx",
+ tokens=f"{args.asr_model_dir}/tokens.txt",
+ num_threads=2,
+ sample_rate=16000,
+ feature_dim=80,
+ decoding_method="greedy_search",
+ debug=False,
+ )
+
+ _launch_demo(args, model, tokenizer, token2wav, asr_model)
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py
new file mode 120000
index 000000000..2a7808921
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py
@@ -0,0 +1 @@
+../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
\ No newline at end of file