Merge 559f9e2deff33077461428d422d9f03c95988b01 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

This commit is contained in:
Yuekai Zhang 2025-07-24 22:09:54 +05:30 committed by GitHub
commit a5de488304
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 7058 additions and 0 deletions

View File

@ -0,0 +1,55 @@
# Introduction
This recipe includes scripts for training speech2speech models.
# SPEECH2SPEECH
The following table lists the folders for different tasks.
|Recipe | Speech Input | Speech Output | Comment|
|--------------|--------------|---------------|--------|
|Qwen-omni like| Continuous Embeddins| Cosyvoice1 50Hz Single-codebook Token | Text-driven; using Thinker LLM for text token, small Talker LLM for speech token |
### [Qwen-omni like Speech2speech Recipe](./qwen_omni)
[Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni) style model using [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.
<br>
<p align="center">
<img src="assets/framework.png" width="800"/>
<p>
<br>
Command for training is:
```bash
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
```
Command for decoding is:
```bash
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
```
Please see [`prepare.sh`](./prepare.sh) for more details.

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

View File

@ -0,0 +1,234 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
set -eou pipefail
stage=$1
stop_stage=$2
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
if [ ! -L "/workspace/slam" ]; then
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
fi
log "stage 17: Training Speech2Speech Model, full parameters"
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 200 \
--enable-musan False \
--exp-dir $exp_dir \
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True --on-the-fly-speed-perturb False\
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True --on-the-fly-feats True \
--dataset vocalnet_ultrachat_voiceassistant_instruct_s2s --num-epochs 10 \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
# check if the link exists, if not exist, create it
if [ ! -L "/workspace/slam" ]; then
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
fi
log "stage 17: Training Speech2Speech Model, full parameters"
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 200 \
--enable-musan False \
--exp-dir $exp_dir \
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True --on-the-fly-speed-perturb False\
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True --on-the-fly-feats True \
--dataset vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech --num-epochs 10 \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "stage 19: Training TTS Model"
exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
pretrained_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
# --dataset ultra_chat_voice_assistant
train_cmd_args="--batch-size 30 \
--exp-dir $exp_dir \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--enable-speech-input False \
--deepspeed \
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--num-epochs 3 \
--use-lora False --unfreeze-llm False --enable-speech-output True"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train_tts.py \
$train_cmd_args
fi
# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
# log "stage 20: Training TTS Model"
# echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
# if [ ! -L "/workspace/slam" ]; then
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
# fi
# exp_dir=./qwen_omni/exp_test
# ngpu=4
# latest_checkpoint_step=-1
# # Check if exp_dir exists and is a directory
# if [ -d "$exp_dir" ]; then
# # List directories matching checkpoint-* and find the one with the largest step number
# for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
# checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# # Extract step number using parameter expansion
# current_step=${checkpoint_name#checkpoint-}
# # Ensure current_step is a number
# if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
# latest_checkpoint_step=$current_step
# fi
# done
# fi
# train_cmd_args="--max-duration 150 \
# --enable-musan False \
# --exp-dir $exp_dir \
# --speech-encoder-path-or-name models/large-v2.pt \
# --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
# --dataset vocalnet_ultrachat_voiceassistant \
# --manifest-dir data/fbank \
# --deepspeed \
# --deepspeed_config ./qwen_omni/ds_config_zero1.json \
# --use-flash-attn True --on-the-fly-feats True \
# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
# if [ "$latest_checkpoint_step" -ge 0 ]; then
# log "Continuing training from checkpoint-$latest_checkpoint_step"
# step=$latest_checkpoint_step
# train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
# else
# log "Starting training from scratch as no checkpoint was found in $exp_dir"
# # No pretrained model or sampler state dict needed for the first run
# fi
# torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
# $train_cmd_args
# fi
# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
# log "stage 21: TTS Decoding Test Set"
# exp_dir=./qwen_omni/exp_tts
# torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
# --exp-dir $exp_dir \
# --speech-encoder-path-or-name models/large-v2.pt \
# --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
# --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
# --use-flash-attn True \
# --enable-speech-output True \
# --token2wav-path /workspace/CosyVoice2-0.5B \
# --use-lora True
# fi

View File

@ -0,0 +1,291 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
# Copyright 2025 Nvidia (Yuekai Zhang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank \
--huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
--audio-key question_audio --text-key answer \
--prefix ultrachat
"""
import argparse
import logging
from pathlib import Path
import torch
from datasets import load_dataset
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
from vocalnet_lhotse_cutset import LazyCustomDatasetIterator
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=True,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
parser.add_argument(
"--resample-to-16kHz",
type=str2bool,
default=True,
help="Resample audio to 16kHz. Default: False.",
)
parser.add_argument(
"--speed-perturb",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--out-dir",
type=str,
default="data/fbank",
help="Output directory for the computed features",
)
parser.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default="/workspace/Belle_1.4M-SLAM-Omni",
help="The path or name of the Huggingface dataset",
)
parser.add_argument(
"--audio-key",
type=str,
default="question_audio",
help="The key in the Huggingface dataset containing the audio data",
)
parser.add_argument(
"--text-key",
type=str,
default="answer",
help="The key in the Huggingface dataset containing the text data",
)
parser.add_argument(
"--prefix",
type=str,
default="belle",
help="""The dataset prefix to use when saving the features""",
)
parser.add_argument(
"--json-file-path",
type=str,
default=None,
help="The path to the json file containing the vocalnet data",
)
parser.add_argument(
"--drop-recordings",
type=str2bool,
default=True,
help="Drop recordings. Default: False.",
)
parser.add_argument(
"--subset",
type=str,
default=None,
help="The subset to use from the Huggingface dataset",
)
parser.add_argument(
"--split",
type=str,
default="train",
help="The split to use from the Huggingface dataset",
)
return parser
def remove_short_and_long_utt(c):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 50.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
def compute_fbank(args):
in_out_dir = Path(args.out_dir)
in_out_dir.mkdir(parents=True, exist_ok=True)
# number of workers in dataloader
num_workers = 4
# number of seconds in a batch
batch_duration = 10
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else:
raise NotImplementedError("Only WhisperFbank is implemented.")
logging.info(f"device: {device}")
dataset = load_dataset(
args.huggingface_dataset_path_or_name,
args.subset,
streaming=True,
split=args.split,
)
num_shards = dataset.num_shards
num_digits = 5
for i in range(252, num_shards):
shard = dataset.shard(num_shards, i)
# shard = shard.take(10) # for testing
logging.info(
f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
)
idx = f"{i}".zfill(num_digits)
cut_set = CutSet.from_huggingface_dataset(
shard, audio_key=args.audio_key, text_key=args.text_key
)
cut_set = cut_set.filter(remove_short_and_long_utt)
if args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
if args.speed_perturb:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{idx}_{args.subset}",
num_workers=num_workers,
batch_duration=batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
# cut_set = cut_set.trim_to_supervisions(
# keep_overlapping=False, min_duration=None
# )
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.{args.subset}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
if args.drop_recordings:
cut_set.drop_recordings().to_file(cuts_path)
else:
cut_set.to_file(cuts_path)
def compute_fbank_vocalnet(args):
in_out_dir = Path(args.out_dir)
in_out_dir.mkdir(parents=True, exist_ok=True)
# number of workers in dataloader
num_workers = 4
# number of seconds in a batch
batch_duration = 10
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else:
raise NotImplementedError("Only WhisperFbank is implemented.")
logging.info(f"device: {device}")
num_shards = 50
num_digits = 5
for i in range(num_shards):
logging.info(f"Processing shard {i}")
idx = f"{i}".zfill(num_digits)
cut_set = CutSet(
LazyCustomDatasetIterator(
json_file_path=args.json_file_path, shard_id=i, num_shards=num_shards
)
)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
if args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
if args.speed_perturb:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{idx}",
num_workers=num_workers,
batch_duration=batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
cut_set.to_file(cuts_path)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
if args.json_file_path is not None:
compute_fbank_vocalnet(args)
else:
compute_fbank(args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,99 @@
# https://huggingface.co/datasets/VocalNet/UltraChat-vocalnet/blob/main/UltraChat.json
# https://huggingface.co/datasets/VocalNet/VoiceAssistant-430K-vocalnet/blob/main/VoiceAssistant-430K.json
import json
import os
import numpy as np
from lhotse import CutSet
from lhotse.audio import Recording
from lhotse.supervision import SupervisionSegment
class LazyCustomDatasetIterator:
"""
Thin wrapper on top of HF datasets objects that allows to interact with them through a Lhotse CutSet.
It can be initialized with an existing HF dataset, or args/kwargs passed on to ``datasets.load_dataset()``.
Use ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` options to indicate which keys in dict examples
returned from HF Dataset should be looked up for audio, transcript, language, and gender respectively.
The remaining keys in HF dataset examples will be stored inside ``cut.custom`` dictionary.
Example with existing HF dataset::
>>> import datasets
... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
... dataset = dataset.map(some_transform)
... cuts_it = LazyHFDatasetIterator(dataset)
... for cut in cuts_it:
... pass
Example providing HF dataset init args/kwargs::
>>> import datasets
... cuts_it = LazyHFDatasetIterator("mozilla-foundation/common_voice_11_0", "hi", split="test")
... for cut in cuts_it:
... pass
"""
def __init__(self, json_file_path: str, shard_id: int = 0, num_shards: int = 100):
self.json_file_path = json_file_path
self.shard_id = shard_id
self.num_shards = num_shards
def __iter__(self):
with open(self.json_file_path, "r", encoding="utf-8") as f:
list_data_dict = json.load(f)
list_data_dict = list_data_dict[self.shard_id :: self.num_shards]
for item in list_data_dict:
custom_data = item.copy()
json_file_parent_of_parent_dir = os.path.dirname(
os.path.dirname(self.json_file_path)
)
units_path = os.path.join(
json_file_parent_of_parent_dir, custom_data["units"]
)
speech_token_dict = np.load(units_path, allow_pickle=True).item()
speech_token = speech_token_dict["speech_token"].squeeze(0).tolist()
speech_token_len = speech_token_dict["speech_token_len"]
assert len(speech_token) == speech_token_len
custom_data["speech_token"] = speech_token
audio_path = custom_data.pop("speech", None)
audio_path = os.path.join(json_file_parent_of_parent_dir, audio_path)
item_id = item.get("id")
recording = Recording.from_file(path=audio_path, recording_id=item_id)
conversations = item.get("conversations")
assert isinstance(conversations, list) and len(conversations) == 2
for conv in conversations:
if isinstance(conv, dict) and conv.get("from") == "gpt":
gpt_text = conv.get("value")
break
assert gpt_text is not None
supervision = SupervisionSegment(
id=item_id,
recording_id=recording.id,
start=0.0, # Assuming the supervision covers the entire recording
duration=recording.duration,
text=gpt_text,
)
cut = recording.to_cut()
# cut.id will be the same as recording.id
cut.supervisions = [supervision]
# custom_data contains the original item's fields, minus "speech".
# So, "id", "conversations", "units", etc., are preserved here.
custom_data.pop("conversations")
custom_data.pop("units")
cut.custom = custom_data
yield cut
if __name__ == "__main__":
json_file_path = (
"/workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json"
)
cut_set = CutSet(LazyCustomDatasetIterator(json_file_path=json_file_path))
for cut in cut_set:
print(cut)
input()

View File

@ -0,0 +1,444 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
set -eou pipefail
stage=$1
stop_stage=$2
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Clone CosyVoice repo and install requirements inside the container"
# docker: ghcr.io/swivid/f5-tts:main
pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
cd /workspace/CosyVoice
# If you failed to clone submodule due to network failures, please run following command until success
git submodule update --init --recursive
pip install -r qwen_omni/requirements.txt
pip install -r qwen_omni/requirements-cosyvoice.txt
# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
pip install sherpa-onnx
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
fi
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Compute fbank feature from huggingface"
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_test \
--huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
--audio-key question_audio --text-key answer \
--prefix belle
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Combine features"
manifest_dir=data/fbank
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
# exclude cust_belle_00000.jsonl.gz for valid and test set
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
echo $pieces | wc
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
fi
fi
ngpu=8
exp_dir=./qwen_omni/exp_speech2speech
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Decoding, only support batch_size=1 for now."
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: Gradio Demo"
python3 ./qwen_omni/web_demo.py \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-999.pt \
--use-flash-attn True \
--enable-speech-output True \
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 6: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_voice_assistant \
# --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix voice_assistant
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_voice_assistant_cosy2 \
--json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
--prefix voice_assistant
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "stage 7: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_ultrachat \
# --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix ultrachat
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_ultrachat_cosy2 \
--json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \
--prefix ultrachat
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "stage 8: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset test --split test \
--audio-key audio --text-key text \
--prefix gigaspeech
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset xl --split train \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
ngpu=4
exp_dir=./qwen_omni/exp_speech2speech_en
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "stage 10: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 150 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--dataset-format vocalnet \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True --on-the-fly-feats True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "stage 11: Decoding EN, val set only support batch_size=1 for now."
exp_dir=./qwen_omni/exp_speech2speech_en_continue
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 997 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch4_speech2speech \
--enable-speech-output True \
--token2wav-path /workspace/CosyVoice2-0.5B \
--use-lora True
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "stage 12: Decoding EN voicebench"
exp_dir=./qwen_omni/exp_speech2speech_en_continue
torchrun --nproc_per_node=2 \
./qwen_omni/decode_dist.py \
--output-dir $exp_dir/log_voicebench \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--use-flash-attn True \
--enable-speech-output True \
--checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
--use-lora True --subset-name openbookqa --split-name test
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "stage 13: Server"
exp_dir=./qwen_omni/exp_speech2speech_en_continue
python3 ./qwen_omni/server.py \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output True \
--use-lora True
fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "stage 14: Client"
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
# exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
# The final assignment of datasets in the original script is used here:
# (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh")
declare -a target_datasets=("mmsu")
NUM_CLIENT_JOBS=4 # Number of parallel client jobs
BASE_PORT=8000 # Base port for servers
log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."
for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
do
( # Start a subshell for backgrounding this client job's tasks
current_port=$(expr $BASE_PORT + $job_id)
log "Client Job $job_id: Initializing. Will connect to port $current_port."
processed_count_for_this_job=0
# Iterate over all datasets using their indices
for i in "${!target_datasets[@]}"; do
# Assign dataset to job_id in a round-robin fashion
if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
dataset="${target_datasets[$i]}"
# local split_name # Determine split_name based on dataset
if [ "$dataset" == "sd-qa" ]; then
split_name="usa"
else
split_name="test"
fi
log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
python3 ./qwen_omni/client.py \
--subset-name "$dataset" \
--split-name "$split_name" \
--output-dir "$exp_dir/results" \
--port "$current_port" # Assuming client.py accepts --port
if [ $? -ne 0 ]; then
log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
fi
processed_count_for_this_job=$(($processed_count_for_this_job + 1))
fi
done
log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
) & # Run this client job's subshell in the background
done
log "All client jobs launched. Waiting for completion..."
wait # Wait for all backgrounded client jobs to complete
log "All client jobs have completed."
fi
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
log "stage 15: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 700 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
log "stage 16: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 800 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
$train_cmd_args
fi
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
# pip install gradio sherpa-onnx
log "stage 17: Server for adapter only speech continuation"
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
N_GPUS=4 # Define the number of GPUs/processes you want to launch
for id in $(seq 0 $(($N_GPUS - 1)))
do
log "Launching server on GPU $id with port $(expr 8000 + $id)"
CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/checkpoint-55276/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output False \
--port $(expr 18000 + $id) \
--use-lora True &
done
wait # Wait for all background processes to complete
fi
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
log "stage 18: Training kl-div Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text_kl
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 700 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--loss-type kl_div --dataset librispeech \
--pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "stage 19: Server for kl loss"
exp_dir=./qwen_omni/exp_speech2text_kl
python3 ./qwen_omni/server.py \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output False \
--use-lora False --prompt-template qa
fi
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage"
exp_dir=./qwen_omni/exp_speech2text_kl_llm
pretrained_dir=./qwen_omni/exp_speech2text_kl
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 200 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet
fi

View File

@ -0,0 +1,156 @@
# client.py
import argparse
import json
import os
import requests
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(description="Speech-to-Text Client")
parser.add_argument(
"--server-url",
type=str,
default="http://localhost",
help="URL of the FastAPI server",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port of the FastAPI server",
)
parser.add_argument(
"--dataset-name",
type=str,
default="hlt-lab/voicebench",
help="Hugging Face dataset name",
)
parser.add_argument(
"--subset-name",
type=str,
default="commoneval", # Adjust as needed
help="Dataset subset name",
)
parser.add_argument(
"--split-name",
type=str,
default=None, # Adjust as needed
help="Dataset split name",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="Directory to save results"
)
args = parser.parse_args()
return args
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
output_filename = os.path.join(
args.output_dir,
f"{args.subset_name}-{args.split_name}.jsonl",
)
server_decode_url = f"{args.server_url}:{args.port}/decode"
print("Loading dataset...")
if args.subset_name != "mmsu":
dataset = load_dataset(
args.dataset_name,
args.subset_name,
split=args.split_name,
trust_remote_code=True,
)
else:
# load all splits and concatenate them
dataset = load_dataset(
args.dataset_name,
args.subset_name,
trust_remote_code=True,
)
dataset = concatenate_datasets([dataset[subset] for subset in dataset])
print(f"Dataset loaded with {len(dataset)} samples.")
print(f"Sending requests to {server_decode_url}...")
print(f"Saving results to {output_filename}")
with open(output_filename, "w", encoding="utf-8") as outfile:
# Iterate directly over the dataset
progress_bar = tqdm(dataset, desc="Processing", unit="samples")
for item in progress_bar:
audio_info = item.get("audio")
assert (
audio_info["sampling_rate"] == 16000
), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
# Prepare data for JSON serialization and server request
audio_array = audio_info["array"].tolist() # Convert numpy array to list
result_dict = {}
for key in item.keys():
if key != "audio":
# Ensure other fields are JSON serializable
try:
# Attempt to serialize to catch issues early (optional)
json.dumps(item[key])
result_dict[key] = item[key]
except (TypeError, OverflowError):
print(
f"Warning: Converting non-serializable key '{key}' to string."
)
result_dict[key] = str(
item[key]
) # Convert problematic types to string
payload = {
"audio": audio_array,
"sampling_rate": 16000,
}
try:
response = requests.post(server_decode_url, json=payload, timeout=60)
response.raise_for_status()
server_response = response.json()
decoded_text = server_response.get("text", "")
# Add the response to the result dictionary
result_dict["response"] = decoded_text
print(result_dict)
# Write result to JSONL file
json.dump(result_dict, outfile, ensure_ascii=False)
outfile.write("\n")
except requests.exceptions.RequestException as e:
print(f"\nError sending request for an item: {e}")
error_entry = result_dict # Use the data prepared so far
error_entry["error"] = str(e)
error_entry["response"] = ""
json.dump(error_entry, outfile, ensure_ascii=False)
outfile.write("\n")
except json.JSONDecodeError:
print("\nError decoding server response for an item.")
error_entry = result_dict
error_entry["error"] = "Invalid JSON response from server"
error_entry["response"] = ""
json.dump(error_entry, outfile, ensure_ascii=False)
outfile.write("\n")
except Exception as e:
print(f"\nUnexpected error processing an item: {e}")
error_entry = result_dict
error_entry["error"] = f"Unexpected error: {str(e)}"
error_entry["response"] = ""
json.dump(error_entry, outfile, ensure_ascii=False)
outfile.write("\n")
# Progress bar updates automatically by iterating over tqdm(dataset)
# No need to close progress_bar explicitly when iterating directly
print("Processing finished.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,813 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
load_manifest,
load_manifest_lazy,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PerturbSpeed,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from utils import get_local_rank, str2bool
import io
import wave
import random
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=300.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--on-the-fly-speed-perturb",
type=str2bool,
default=True,
help="When enabled, use on-the-fly speed perturbation. "
"Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=4,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default=None,
help="The path or name of the Huggingface dataset",
)
group.add_argument(
"--audio-key",
type=str,
default=None,
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
default=None,
help="The key in the Huggingface dataset containing the text data",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
rank = get_local_rank()
train = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 1000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True if self.args.num_workers > 0 else False,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
"""
Args:
cuts_valid:
CutSet for validation.
"""
logging.info("About to create dev dataset")
rank = get_local_rank()
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
else:
valid_sampler = SimpleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_num_workers = 1
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=valid_num_workers,
persistent_workers=True if valid_num_workers > 0 else False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def test_cuts_belle(self) -> CutSet:
logging.info("About to get test cuts")
return {
"test": load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
}
@lru_cache()
def dev_cuts_belle(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
@lru_cache()
def train_cuts_belle(self) -> CutSet:
logging.info("About to get train cuts")
slam_omni_zh_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_train.jsonl.gz"
)
return slam_omni_zh_cuts
@lru_cache()
def train_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get train cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
)
ultrachat_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
)
VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
ultrachat_cuts = ultrachat_cuts.resample(16000)
return CutSet.mux(
VoiceAssistant_cuts,
ultrachat_cuts,
weights=[
len(VoiceAssistant_cuts),
len(ultrachat_cuts),
],
)
@lru_cache()
def valid_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get valid cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
)
VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
return VoiceAssistant_cuts
@lru_cache()
def test_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get test cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
)
VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
return {"test": VoiceAssistant_cuts}
@lru_cache()
def train_cuts_ultravox(self) -> CutSet:
logging.info("About to get train cuts")
if self.args.huggingface_dataset_path_or_name is not None:
librispeech_path = (
self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
)
people_speech_path = (
self.args.huggingface_dataset_path_or_name + "/peoples_speech"
)
gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
else:
librispeech_path = "fixie-ai/librispeech_asr"
people_speech_path = "fixie-ai/peoples_speech"
gigaspeech_path = "fixie-ai/gigaspeech"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
# 1_501_271
people_speech_clean = load_dataset(
people_speech_path, "clean", split="train", streaming=True
)
# 548_000
people_speech_dirty_sa = load_dataset(
people_speech_path, "dirty_sa", split="train", streaming=True
)
# 8_266_422
gigaspeech = load_dataset(
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key="audio",
text_key="text",
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key="audio",
text_key="text",
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key="audio",
text_key="text",
)
gigaspeech_cuts = CutSet.from_huggingface_dataset(
gigaspeech, audio_key="audio", text_key="text"
)
people_speech_clean_cuts = CutSet.from_huggingface_dataset(
people_speech_clean,
audio_key="audio",
text_key="text",
)
people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
people_speech_dirty_sa,
audio_key="audio",
text_key="text",
)
return CutSet.mux(
librispeech_clean_100_cuts,
librispeech_clean_360_cuts,
librispeech_other_cuts,
gigaspeech_cuts,
people_speech_clean_cuts,
people_speech_dirty_sa_cuts,
weights=[
28539,
104014,
148688,
8266422,
1501271,
548000,
],
)
@lru_cache()
def valid_cuts_ultravox(self) -> CutSet:
logging.info("About to get valid cuts")
librispeech_path = "fixie-ai/librispeech_asr"
librispeech_clean_valid = load_dataset(
librispeech_path, "clean", split="validation", streaming=True
)
librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_valid,
audio_key="audio",
text_key="text",
)
return librispeech_clean_valid_cuts
@lru_cache()
def train_cuts_librispeech(self) -> CutSet:
logging.info("About to get train cuts")
if self.args.huggingface_dataset_path_or_name is not None:
librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
else:
librispeech_path = "fixie-ai/librispeech_asr"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key="audio",
text_key="text",
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key="audio",
text_key="text",
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key="audio",
text_key="text",
)
return CutSet.mux(
librispeech_clean_100_cuts,
librispeech_clean_360_cuts,
librispeech_other_cuts,
weights=[
28539,
104014,
148688,
],
)
@lru_cache()
def train_cuts_gigaspeech(self) -> CutSet:
logging.info("About to get train cuts")
gigaspeech_path = "fixie-ai/gigaspeech"
gigaspeech = load_dataset(
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
)
gigaspeech_cuts = CutSet.from_huggingface_dataset(
gigaspeech, audio_key="audio", text_key="text"
)
return gigaspeech_cuts
@lru_cache()
def train_cuts_instruct_s2s(self) -> CutSet:
logging.info("About to get train cuts")
if self.args.huggingface_dataset_path_or_name is not None:
data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
else:
data_path = "yuekai/InstructS2S-200K"
# 148_688
instruct_s2s_train = load_dataset(
data_path, split="train", streaming=True
)
instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
instruct_s2s_train,
audio_key="question_audio",
text_key="answer",
)
instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
return instruct_s2s_train_cuts
@lru_cache()
def train_cuts_en_speech2speech(self) -> CutSet:
logging.info("About to get train cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
)
ultrachat_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
)
if self.args.huggingface_dataset_path_or_name is not None:
data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
else:
data_path = "yuekai/InstructS2S-200K"
# 148_688
instruct_s2s_train = load_dataset(
data_path, split="train", streaming=True
)
instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
instruct_s2s_train,
audio_key="question_audio",
text_key="answer",
)
instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
return CutSet.mux(
VoiceAssistant_cuts,
ultrachat_cuts,
instruct_s2s_train_cuts,
weights=[
len(VoiceAssistant_cuts),
len(ultrachat_cuts),
423_000,
],
)
@lru_cache()
def train_cuts_en_speech2speech_librispeech(self) -> CutSet:
logging.info("About to get train cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
)
ultrachat_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
)
if self.args.huggingface_dataset_path_or_name is not None:
data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
else:
data_path = "yuekai/InstructS2S-200K"
# 148_688
instruct_s2s_train = load_dataset(
data_path, split="train", streaming=True
)
instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
instruct_s2s_train,
audio_key="question_audio",
text_key="answer",
)
instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
if self.args.huggingface_dataset_path_or_name is not None:
librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
else:
librispeech_path = "fixie-ai/librispeech_asr"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key="audio",
text_key="text",
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key="audio",
text_key="text",
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key="audio",
text_key="text",
)
return CutSet.mux(
librispeech_other_cuts,
VoiceAssistant_cuts,
ultrachat_cuts,
librispeech_clean_360_cuts,
instruct_s2s_train_cuts,
librispeech_clean_100_cuts,
weights=[
148688,
len(VoiceAssistant_cuts),
len(ultrachat_cuts),
104014,
423_000,
28539,
],
)
@lru_cache()
def train_cuts_emilia_en(self) -> CutSet:
logging.info("About to get train cuts")
data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
# if self.args.huggingface_dataset_path_or_name is not None:
# data_path = self.args.huggingface_dataset_path_or_name + "/emilia_en"
# else:
# data_path = "yuekai/emilia_en"
emilia_en_data = load_dataset(
data_path, split="train", streaming=True
)
def update_wav_path(example):
sampling_rate = 16000 # From current_features
duration = 1 # seconds, arbitrary duration for random audio
num_channels = 1 # mono
sample_width = 2 # 2 bytes = 16-bit audio
num_frames = int(duration * sampling_rate)
# Generate random bytes for the PCM data part
# This will be random noise, but structurally valid for a WAV file
pcm_data = bytes([random.randint(0, 255) for _ in range(num_frames * num_channels * sample_width)])
# Create a WAV file in memory
audio_buffer = io.BytesIO()
with wave.open(audio_buffer, 'wb') as wf:
wf.setnchannels(num_channels)
wf.setsampwidth(sample_width)
wf.setframerate(sampling_rate)
wf.writeframes(pcm_data) # writeframes expects bytes
example["wav"] = audio_buffer.getvalue()
return example
emilia_en_data = emilia_en_data.map(update_wav_path)
current_features = Features({
'id': Value('string'),
'text': Value('string'),
'duration': Value('float'),
'language': Value('string'),
'dnsmos': Value('float'),
'speech_token': Sequence(Value('int32')),
'wav': Audio(sampling_rate=16000)
})
emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
emilia_en_data = emilia_en_data.cast(current_features)
emilia_en_train_cuts = CutSet.from_huggingface_dataset(
emilia_en_data, # Adjusted from instruct_s2s_train
audio_key="wav",
text_key="text",
)
return emilia_en_train_cuts

View File

@ -0,0 +1,759 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
"""
import argparse
import logging
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import soundfile as sf
import torch
import torch.nn as nn
import transformers
import whisper
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
Args:
audio_tokens (list): List of audio tokens to be processed.
model_config: Configuration object containing vocab settings.
codec_decoder: Codec decoder for generating audio.
tone_dir (str): The tone directory or setting.
audio_prompt_path (str, optional): Path to the audio prompt file. Required when tone_dir is not "default_tone".
code_layer (int, optional): Number of code layers. Defaults to 1.
num_latency_tokens (int, optional): Number of latency tokens to ignore. Defaults to 0.
speed (float, optional): Speed factor for audio generation. Defaults to 1.0.
Returns:
torch.Tensor: Generated audio waveform.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def audio_decode_cosyvoice(audio_tokens, codec_decoder):
"""
Generate audio from tokens with optional tone and prompt embedding.
Args:
audio_tokens (list): List of audio tokens to be processed.
codec_decoder: Codec decoder for generating audio.
Returns:
torch.Tensor: Generated audio waveform.
"""
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
prompt_speech_feat = torch.zeros(1, 0, 80)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
prompt_token_len=torch.tensor(
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
prompt_feat_len=torch.tensor(
[prompt_speech_feat.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
embedding=flow_embedding.to(codec_decoder.model.device),
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def get_model(params, device):
"""Load and prepare the speech-to-speech model."""
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
# TODO: FIX ME
# codec_vocab_size = 4096 + 4
codec_vocab_size = 6561 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
else:
codec_lm = None
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
return model, tokenizer
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
"""Average a list of checkpoints.
The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.
Args:
filenames:
Filenames of the checkpoints to be averaged. We assume all
checkpoints are saved by :func:`save_checkpoint`.
device:
Move checkpoints to this device before averaging.
Returns:
Return a dict (i.e., state_dict) which is the average of all
model state dicts contained in the checkpoints.
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
for k in uniqued_names:
avg[k] += state_dict[k]
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else:
avg[k] //= n
return avg
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="beam-search",
help="""Decoding method.
Supported values are:
- beam-search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=1,
help="beam size for beam search decoding",
)
parser.add_argument(
"--exp-dir",
type=str,
default="whisper/exp",
help="The experiment dir",
)
parser.add_argument(
"--token2wav-path",
type=str,
default="/workspace/CosyVoice-300M-SFT",
help="The path to the token2wav model",
)
parser.add_argument(
"--prompt_text",
type=str,
default="Romeo and Juliet might be the most famous act of William Shakespeare.",
help="The prompt text",
)
parser.add_argument(
"--prompt_speech_path",
type=str,
default="./assets/common_voice_en_2586258.wav",
help="The path to the prompt speech",
)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
params = AttributeDict({})
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
token2wav_model: nn.Module,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: "beam-search"
- value: A list of lists. Each sublist is a list of token IDs.
Args:
params:
It is returned by :func:`get_params`.
model:
The neural model.
batch:
It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
Returns:
Return a dict, whose key may be "beam-search".
"""
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="longest",
truncation=False,
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
dtype = torch.float32
device = model.llm.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
if not params.remove_whisper_encoder_input_length_restriction:
T = 3000
if feature.shape[2] < T:
feature = torch.cat(
[
feature,
torch.zeros(
feature.shape[0], feature.shape[1], T - feature.shape[2]
).to(device, dtype=dtype),
],
2,
)
# chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
# questions_with_history = [
# cut.custom["question"] for cut in batch["supervisions"]["cut"]
# ]
# history_contexts = [
# question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
# ]
# last_questions = [
# question.split("<USER>: ")[-1].strip() for question in questions_with_history
# ]
# messages = []
# for i, total_round in enumerate(chat_rounds):
# message = []
# if total_round > 1:
# history_question_answer = history_contexts[i].split("USER:")
# history_question_answer = [item for item in history_question_answer if item]
# for j in range(total_round - 1):
# question_answer = history_question_answer[j].split("ASSISTANT:")
# message += [
# {"role": "user", "content": question_answer[0].strip()},
# {"role": "assistant", "content": question_answer[1].strip()},
# ]
# message += [
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": ""},
# ]
# print(f"message: {message}, batch_size {len(chat_rounds)}")
# messages.append(message)
messages = []
for i in range(len(batch["supervisions"]["cut"])):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]
messages.append(message)
input_ids, attention_mask = preprocess(messages, tokenizer)
if params.enable_speech_output:
generated_ids, generated_speech_output = model.decode_with_speech_output(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
generated_speech_output = [
generated_speech_output
] # WAR: only support batch = 1 for now
for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
speech_file_name = params.log_dir / f"{cut_id}.wav"
# audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
if "CosyVoice2" in params.token2wav_path:
prompt_speech_16k = load_wav(params.prompt_speech_path, 16000)
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
params.prompt_text,
prompt_speech_16k,
token2wav_model,
)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
else:
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
else:
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
print(f"hyps: {hyps}")
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
token2wav_model: nn.Module,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
The dataloader.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "beam-search".
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
# questions_with_history = [
# cut.custom["question"] for cut in batch["supervisions"]["cut"]
# ]
# texts = [
# question.split("<USER>: ")[-1].strip()
# for question in questions_with_history
# ]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
token2wav_model=token2wav_model,
batch=batch,
tokenizer=tokenizer,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
params.log_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
logging.info(f"device: {device}")
model, tokenizer = get_model(params, device)
if "CosyVoice2" in params.token2wav_path:
token2wav_model = CosyVoice2(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
else:
token2wav_model = CosyVoice(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
args.return_cuts = True
data_module = AsrDataModule(args)
def remove_long_utt(c: Cut):
# Keep only utterances with duration in 30 seconds
#
if c.duration > 30.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
return True
# TODO: FIX ME
# test_sets_cuts = data_module.test_cuts_belle()
test_sets_cuts = data_module.test_cuts_en_vocalnet()
test_sets = test_sets_cuts.keys()
test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
for cuts_name in test_sets
]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
token2wav_model=token2wav_model,
tokenizer=tokenizer,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,256 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
split=test_zh
llm_path=f5-tts/exp_zh/checkpoint-805000
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output_dir $output_dir \
--batch_size 1 \
--num_workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
"""
import argparse
import json
import os
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn.functional as F
import whisper
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--split-name",
type=str,
default="test",
help="huggingface dataset split name",
)
parser.add_argument(
"--subset-name",
type=str,
default="commoneval",
help="subset name",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=2, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=2, help="prefetch for dataloader"
)
parser.add_argument(
"--checkpoint-path",
type=str,
default=None,
help="Checkpoint name or path, default to %(default)r",
)
# parser.add_argument(
# "--top-k",
# type=int,
# default=50,
# help="top k for sampling",
# )
# parser.add_argument(
# "--top-p",
# type=float,
# default=0.95,
# help="top p for sampling",
# )
# parser.add_argument(
# "--temperature",
# type=float,
# default=0.8,
# help="temperature for sampling",
# )
add_model_arguments(parser)
args = parser.parse_args()
return args
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def preprocess(
messages,
tokenizer,
):
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="longest",
truncation=False,
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
def custom_collate(batch):
assert len(batch) == 1
audio = batch[0]["audio"]
assert audio["sampling_rate"] == 16000
result = {"audio": audio["array"]}
for keys in batch[0].keys():
if keys != "audio":
result[keys] = batch[0][keys]
return result
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
dataset = load_dataset(
"hlt-lab/voicebench",
args.subset_name,
split=args.split_name,
trust_remote_code=True,
)
model, tokenizer = get_model(args)
# tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=custom_collate,
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]
input_ids, attention_mask = preprocess([message], tokenizer)
results_jsonl_file = open(
os.path.join(
args.output_dir,
f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
),
"w",
)
for batch in dataloader:
audio = batch["audio"]
audio = torch.from_numpy(audio).to(device).to(torch.float32)
fbank = whisper.log_mel_spectrogram(audio, device=device)
fbank = fbank.unsqueeze(0)
generated_ids = model.decode(
fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
result_dict = {}
for key in batch.keys():
if key != "audio":
result_dict[key] = batch[key]
result_dict["response"] = hyps[0]
json.dump(result_dict, results_jsonl_file)
results_jsonl_file.write("\n")
if rank == 0:
progress_bar.update(world_size * args.batch_size)
if rank == 0:
progress_bar.close()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,310 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
import copy
import logging
import os
import random
import sys
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import soundfile as sf
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from cosyvoice.cli.cosyvoice import CosyVoice2
from datasets import Audio, load_dataset
from decode import audio_decode_cosyvoice2
from label_smoothing import LabelSmoothingLoss
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from train import add_model_arguments, add_training_arguments, get_model, get_params
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_local_rank,
get_rank,
get_world_size,
setup_logger,
str2bool,
)
# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="The batch size to use.",
)
parser.add_argument(
"--split-name",
type=str,
default="test_en",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument(
"--token2wav-path",
type=str,
default="/workspace/CosyVoice-300M-SFT",
help="The path to the token2wav model",
)
add_model_arguments(parser)
add_training_arguments(parser)
return parser
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
# WAR: TODO FIXME check qwen3
# THIS IS THE ONLY DIFFERENCE FROM preprocess
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
target_ids[row, col] = default_speech_token_id
# remove default_speech_token_id from target_ids and input_ids
batch_size = target_ids.size(0)
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
def data_collator(batch):
prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
for i, item in enumerate(batch):
# speech_tokens.append(item["prompt_audio_cosy2_tokens"])
message_list_item = []
message_list_item += [
{
"role": "user",
"content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": ""},
]
messages.append(message_list_item)
target_texts.append(item["target_text"])
ids.append(item["id"])
prompt_texts.append(item["prompt_text"])
speech_org = item["prompt_audio"]
speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
speech_org = speech_org.mean(dim=0, keepdim=True)
prompt_speech_16k.append(speech_org)
# resample to 16k
return {
"prompt_texts": prompt_texts,
"target_texts": target_texts,
"prompt_speech_16k": prompt_speech_16k,
"messages": messages,
"ids": ids,
}
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params.log_dir = Path(params.exp_dir) / "log-results-wav"
params.log_dir.mkdir(parents=True, exist_ok=True)
fix_random_seed(params.seed)
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-decode-tts")
logging.info(params)
logging.info("About to create model")
model, tokenizer = get_model(params)
if torch.cuda.is_available():
device = torch.device("cuda", get_local_rank())
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
data_loader = DataLoader(
dataset,
batch_size=params.batch_size,
sampler=sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=data_collator,
)
token2wav_model = CosyVoice2(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
for batch in data_loader:
messages = batch["messages"]
prompt_texts = batch["prompt_texts"]
prompt_speech_16k = batch["prompt_speech_16k"]
target_texts = batch["target_texts"]
ids = batch["ids"]
input_ids, attention_mask, _ = preprocess(messages, tokenizer)
generated_ids, generated_speech_output = model.decode_with_speech_output(
None, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
generated_speech_output = [
generated_speech_output
] # WAR: only support batch = 1 for now
for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
):
speech_file_name = params.log_dir / f"{cut_id}.wav"
# save target_text to file
with open(params.log_dir / f"{cut_id}.txt", "w") as f:
f.write(f"{target_text}\n")
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
if "CosyVoice2" in params.token2wav_path:
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
prompt_text,
prompt_speech,
token2wav_model,
)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
logging.info("Done!")
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@ -0,0 +1,838 @@
from typing import List, Tuple
import torch
from torch import nn
from torchmetrics.classification import MulticlassAccuracy
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
Args:
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
llm_dim (:obj:`int`): The dimension of the language model.
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
"""
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
batch_size, seq_len, feat_dim = x.size()
num_frames_to_discard = seq_len % self.downsample_rate
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class SPEECH_LLM(nn.Module):
"""
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
The encoder is used to extract speech features from the input speech signal.
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
The language model is used to generate the text from the speech features.
Args:
encoder (:obj:`nn.Module`): The encoder module.
llm (:obj:`nn.Module`): The language model module.
encoder_projector (:obj:`nn.Module`): The encoder projector module.
"""
def __init__(
self,
encoder: nn.Module = None,
llm: nn.Module = None,
encoder_projector: nn.Module = None,
codec_lm: nn.Module = None,
codec_lm_padding_side: str = "left",
teacher_llm: nn.Module = None,
kl_temperature: float = 2.0,
):
super().__init__()
self.encoder = encoder
self.llm = llm
self.encoder_projector = encoder_projector
self.codec_lm = codec_lm
if self.codec_lm:
self.speech_token_projector = nn.Linear(
self.llm.config.hidden_size + self.llm.config.hidden_size,
self.codec_lm.config.hidden_size,
)
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
)
self.speech_token_projector = self.speech_token_projector.to(
dtype=torch.float16
)
self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
self.loss_fct = torch.nn.CrossEntropyLoss()
self.codec_lm_padding_side = codec_lm_padding_side
self.audio_accuracy_metric = MulticlassAccuracy(
self.codec_lm.vocab_size,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=IGNORE_TOKEN_ID,
)
if teacher_llm is not None:
self.teacher_llm = teacher_llm
self.kl_temperature = kl_temperature
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
"""
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
with the speech features and padding the input_ids to the maximum length of the speech features.
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
Args:
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
input_ids (:obj:`torch.Tensor`): The input ids to merge.
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
Returns:
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
"""
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
)
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (
num_special_speech_tokens.max() * (speech_len - 1)
) + sequence_length
batch_indices, non_speech_indices = torch.where(
input_ids != self.llm.config.default_speech_token_id
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
)
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
batch_size,
max_embed_dim,
dtype=attention_mask.dtype,
device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim),
IGNORE_TOKEN_ID,
dtype=input_ids.dtype,
device=input_ids.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_speech_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_speech_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_speech_indices
]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
batch_indices, non_speech_indices
]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[
batch_indices, non_speech_indices
]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
:, None
].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
final_embedding[speech_to_overwrite] = (
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
)
final_attention_mask |= speech_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
(final_attention_mask == 0), 1
)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(
input_ids == self.llm.config.pad_token_id
)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return model_outputs.loss, acc
def forward_kl_div(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
teacher_input_ids: torch.LongTensor = None,
teacher_attention_mask: torch.Tensor = None,
teacher_labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
)
teacher_outputs = self.teacher_llm(
input_ids=teacher_input_ids,
attention_mask=teacher_attention_mask,
)
kl_loss = torch.nn.functional.kl_div(
torch.nn.functional.log_softmax(
model_outputs.logits[labels != -100] / self.kl_temperature,
dim=-1,
),
torch.nn.functional.softmax(
teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature,
dim=-1,
),
reduction="batchmean",
)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
teacher_preds = torch.argmax(teacher_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
acc_teacher = compute_accuracy(
teacher_preds.detach()[:, :-1],
teacher_labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return kl_loss, acc, acc_teacher
def forward_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
speech_codec_ids: torch.LongTensor = None,
):
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
if fbank is not None:
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
input_seq_len = attention_mask.sum(dim=1) # shape, B
(
text_label_start_index_list,
text_input_start_index_list,
input_question_len_list,
) = ([], [], [])
for i in range(labels.shape[0]):
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0]
assert (
input_seq_len[i]
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert (
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
assert (
input_question_len
+ text_labels_valid_index[-1]
- text_labels_start_index
+ 1
== input_seq_len[i]
)
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
)
text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs
audio_codes_lens = [
len(x) + input_question_len_list[i] + delay_step + 1
for i, x in enumerate(speech_codec_ids)
]
max_len_speech_codec = max(audio_codes_lens)
if self.codec_lm_padding_side == "right":
audio_codes = [
[self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
audio_codes, dtype=torch.int64, device=input_ids.device
)
audio_labels = torch.tensor(
audio_labels, dtype=torch.int64, device=input_ids.device
)
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
# text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
text_input_embeds_list = []
for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][
i,
text_input_start_index_list[i] : text_input_start_index_list[i]
+ input_seq_len[i]
- 1,
]
# text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[
i,
text_input_start_index_list[i]
+ 1 : text_input_start_index_list[i]
+ input_seq_len[i],
] # exclude bos
# text_embeds_list.append(text_embed)
text_input_embeds = torch.cat(
[
text_last_hidden,
text_embed,
],
dim=-1,
) # shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(
text_input_embeds
) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds)
for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right":
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left":
start_idx = torch.where(
audio_codes[i] == self.codec_lm.config.mask_token_id
)[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
assert (
start_idx == start_idx_re_compute
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
)
# breakpoint()
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
audio_embeddings[
i, start_idx : start_idx + text_input_embeds.shape[0]
] += text_input_embeds
speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask,
inputs_embeds=audio_embeddings,
return_dict=True,
output_hidden_states=True,
)
last_hidden_state = speech_outputs.hidden_states[-1].clone()
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
audio_logits = audio_logits.contiguous().view(
-1, self.codec_lm.config.vocab_size
)
audio_labels = audio_labels.contiguous().view(-1)
audio_labels = audio_labels.masked_fill(
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
)
codec_loss = self.loss_fct(audio_logits, audio_labels)
audio_preds = torch.argmax(audio_logits, -1)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
audio_acc = compute_accuracy(
audio_preds.detach(),
audio_labels.detach(),
ignore_label=IGNORE_TOKEN_ID,
)
audio_topk_acc = self.audio_accuracy_metric(
audio_logits.detach(), audio_labels.detach()
).item()
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
def decode(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_new_tokens", 1024),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", True),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 0.5),
top_k=kwargs.get("top_k", 20),
repetition_penalty=kwargs.get("repetition_penalty", 1.1),
temperature=kwargs.get("temperature", 0.7),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
def decode_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 2048, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
"""
Generates text and corresponding speech tokens using the revised logic.
Args:
fbank: Input audio features.
input_ids: Input token IDs for the text prompt.
attention_mask: Attention mask for the text prompt.
max_text_new_tokens: Max new tokens for text generation.
max_speech_new_tokens: Max new tokens for speech generation.
llm_kwargs: Additional arguments for self.llm.generate.
codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
Returns:
Tuple[torch.LongTensor, List[List[int]]]:
- generated_text_ids: Tensor of generated text token IDs (including prompt).
- generated_speech_tokens: List of lists, where each inner list contains
the generated speech codec tokens for a batch item.
"""
batch_size = input_ids.shape[0]
assert batch_size == 1, "Batch size must be 1 for speech generation."
device = next(self.parameters()).device # Use model's device
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
# Merge speech features with prompt embeddings
if fbank is not None:
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
(
merged_prompt_inputs_embeds,
merged_prompt_attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, prompt_embeds, input_ids, attention_mask
)
else:
merged_prompt_inputs_embeds = prompt_embeds
merged_prompt_attention_mask = attention_mask
# --- 2. Generate Text using LLM ---
# Use merged embeds/mask as input to generate
# Ensure kwargs passed are suitable for llm.generate
# Note: Using default generation params from `decode` if not provided in kwargs
final_llm_kwargs = {
"bos_token_id": self.llm.config.bos_token_id,
"eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1,
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5,
"top_k": 20,
"repetition_penalty": 1.1,
"temperature": 0.7,
**(llm_kwargs or {}), # User-provided kwargs override defaults
}
text_outputs = self.llm.generate(
inputs_embeds=merged_prompt_inputs_embeds,
attention_mask=merged_prompt_attention_mask,
max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True,
output_hidden_states=True,
**final_llm_kwargs,
)
delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full]
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(
torch.tensor([[eos_token_id]], device=device)
)
assert (
generated_text_ids[0, -1] == eos_token_id
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
thinker_token_embeds_org[1],
],
dim=1,
)
thinker_token_embeds = (
[first_thinker_token_embed]
+ thinker_token_embeds_org[2:]
+ [eos_token_embedding]
)
thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
thinker_reply_part = [
torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(
thinker_hidden_states[1:], thinker_token_embeds[1:]
)
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
thinker_prompt_part = torch.cat(
[
thinker_hidden_states[0],
thinker_token_embeds[0],
],
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
thinker_reply_part = self.speech_token_projector(thinker_reply_part)
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
self.codec_lm.config.mask_token_id,
dtype=torch.long,
device=self.llm.device,
)
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
thinker_reply_part[:, : delay_step + 1, :],
],
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
past_key_values = None
generated_speech_tokens_list = []
next_token_ids = None
for t in range(max_speech_new_tokens):
if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
next_token_ids
)
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
thinker_reply_part = thinker_reply_part[:, 1:, :]
codec_outputs = self.codec_lm(
inputs_embeds=talker_inputs_embeds,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True,
)
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
next_token_logits = self.codec_lm_head(last_token_hidden_state)
next_token_ids = topk_sampling(
next_token_logits,
)
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(
next_token_ids.squeeze(1).cpu().tolist()[0]
)
return generated_text_ids, generated_speech_tokens_list
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float()
def topk_sampling(
logits,
top_k=50,
top_p=0.95,
temperature=0.8,
):
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits_filtered = top_k_top_p_filtering(
logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
)
# Sample
probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
tokens = torch.multinomial(probs, num_samples=1)
return tokens
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits

View File

@ -0,0 +1,23 @@
conformer==0.3.2
diffusers==0.29.0
gdown==5.1.0
gradio
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
#modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0
protobuf==4.25
pydantic==2.7.0
pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
wget==3.2
WeTextProcessing==1.0.3

View File

@ -0,0 +1,15 @@
openai-whisper
kaldialign
lhotse
sentencepiece
pypinyin
tensorboard
librosa
deepspeed
transformers>=4.37.0
flash-attn
peft
torchmetrics
# triton==3.3.0 # may be violate with openai-whisper
gradio
sherpa-onnx

View File

@ -0,0 +1,131 @@
# server.py
import argparse
import os
from typing import List
import torch
import uvicorn
import whisper
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--checkpoint-path",
type=str,
default=None,
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--prompt-template",
type=str,
default=None,
help="Prompt template",
)
parser.add_argument(
"--port",
type=int,
default=8001,
help="Port number",
)
add_model_arguments(parser)
args = parser.parse_args()
return args
class SpeechRequest(BaseModel):
audio: List[float] # Expecting audio as a list of floats (raw waveform)
sampling_rate: int = 16000
class TextResponse(BaseModel):
text: str
def preprocess_prompt(tokenizer):
"""Preprocesses the prompt template."""
texts = [
tokenizer.apply_chat_template(
message, # Using the hardcoded message
tokenize=True,
add_generation_prompt=False, # Important for generation
chat_template=TEMPLATE,
padding=False, # No padding needed for single prompt
truncation=False,
)
]
input_ids = torch.tensor(texts, dtype=torch.long)
attention_mask = torch.ones_like(
input_ids, dtype=torch.bool
) # Mask is all True for the prompt
return input_ids, attention_mask
args = get_args()
print(f"Using port: {args.port}")
model, tokenizer = get_model(args)
app = FastAPI()
device = torch.device("cuda")
if args.prompt_template is None:
template = f"{DEFAULT_SPEECH_TOKEN}"
elif args.prompt_template == "qa":
template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"
elif args.prompt_template == "continuation":
template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}"
elif args.prompt_template == "asr":
template = (
f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
)
elif args.prompt_template == "mt":
template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}"
else:
raise ValueError(f"Invalid prompt template: {args.prompt_template}")
print("Using template:", template)
message = [
{"role": "user", "content": template},
{"role": "assistant", "content": ""},
]
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
prompt_input_ids = prompt_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)
@app.post("/decode", response_model=TextResponse)
async def decode_speech(request: SpeechRequest):
"""
Receives audio waveform, processes it, and returns the decoded text.
"""
if request.sampling_rate != 16000:
raise HTTPException(
status_code=400, detail="Only 16kHz sampling rate is supported."
)
try:
audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
fbank = fbank.unsqueeze(0)
with torch.no_grad():
generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
response_text = hyps[0] if hyps else ""
return TextResponse(text=response_text)
except Exception as e:
print(f"Error during processing: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
if __name__ == "__main__":
print("Starting server...")
uvicorn.run(app, host="0.0.0.0", port=args.port)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,604 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from datasets import load_dataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from pathlib import Path
from train import add_model_arguments, add_training_arguments, get_params, get_model
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_local_rank,
get_rank,
get_world_size,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--batch-size",
type=int,
default=16,
help="The batch size to use.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
add_training_arguments(parser)
return parser
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
# WAR: TODO FIXME check qwen3
# THIS IS THE ONLY DIFFERENCE FROM preprocess
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
target_ids[row, col] = default_speech_token_id
# remove default_speech_token_id from target_ids and input_ids
batch_size = target_ids.size(0)
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
def data_collator(batch):
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
for i, item in enumerate(batch):
speech_tokens.append(item["code"])
message_list_item = []
message_list_item += [
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": item["text"]},
]
# message_list_item += [
# {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": item["text"]},
# ]
messages.append(message_list_item)
durations.append(item["duration"])
ids.append(item["index"] if "index" in item else item["id"])
lang.append(item["language"])
return {
"speech_tokens": speech_tokens,
"messages": messages,
"durations": durations,
"ids": ids,
"lang": lang,
}
def data_collator_ultra_chat(batch):
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
for i, item in enumerate(batch):
speech_tokens.append(item["custom"]["speech_token"])
text = item["supervisions"][0]["text"]
message_list_item = []
message_list_item += [
{"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": text},
]
messages.append(message_list_item)
durations.append(item["duration"])
ids.append(item["id"])
return {
"speech_tokens": speech_tokens,
"messages": messages,
"durations": durations,
"ids": ids,
}
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute the loss for the given batch.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
Whether it is training.
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
device = next(model.parameters()).device
messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
with torch.set_grad_enabled(is_training):
(
text_loss,
acc,
codec_loss,
codec_acc,
codec_topk_acc,
) = model.forward_with_speech_output(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
speech_codec_ids=answer_cosyvoice_speech_token,
)
loss = text_loss + codec_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = len(messages)
# Note: We use reduction=sum while computing the loss.
info["acc"] = acc * len(messages)
info["codec_acc"] = codec_acc * len(messages)
info["codec_topk_acc"] = codec_topk_acc * len(messages)
info["loss"] = loss.detach().cpu().item()
info["codec_loss"] = codec_loss.detach().cpu().item()
info["text_loss"] = text_loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
# FIX ME
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
# model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["durations"])
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
tokenizer=tokenizer,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
# model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"zero-checkpoint-{params.batch_idx_train}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
tag=f"zero-checkpoint-{params.batch_idx_train}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
# sampler_state_dict = train_dl.sampler.state_dict()
sampler_state_dict = train_dl.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
)
try:
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# deepspeed's backward() is different from torch's backward()
# in that it does not accept a loss tensor as input.
# It computes the loss internally.
model.backward(loss)
model.step()
except: # noqa
raise
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]
except: # noqa
cur_lr = 0.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_value = tot_loss["loss"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params.valid_interval = 2000
fix_random_seed(params.seed)
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
model, tokenizer = get_model(params)
if torch.cuda.is_available():
device = torch.device("cuda", get_local_rank())
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
# assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
)
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
if params.dataset == "ultra_chat_voice_assistant":
data_dir = "data/fbank"
json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
ds = load_dataset("json", data_files=json_file_lists, split="train")
# shuffle the dataset
train_dataset = ds.shuffle(seed=42)
eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
else:
data_dir = Path(params.dataset)
json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
ds = load_dataset("json", data_files=json_file_lists, split="train")
# shuffle the dataset
ds = ds.shuffle(seed=42)
train_test_split = ds.train_test_split(test_size=1000, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_dl = StatefulDataLoader(
train_dataset,
batch_size=params.batch_size,
sampler=sampler,
shuffle=False,
num_workers=4,
prefetch_factor=2,
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
)
train_dl.load_state_dict(sampler_state_dict)
valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
valid_dl = DataLoader(
eval_dataset,
batch_size=params.batch_size,
sampler=valid_sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
tokenizer=tokenizer,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
# sampler_state_dict = train_dl.sampler.state_dict()
sampler_state_dict = train_dl.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,433 @@
import argparse
import collections
import json
import logging
import os
import pathlib
import random
import re
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
from tqdm import tqdm
import kaldialign
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import numpy as np
Pathlike = Union[str, Path]
def get_world_size():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"])
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank():
if "RANK" in os.environ:
return int(os.environ["RANK"])
elif dist.is_available() and dist.is_initialized():
return dist.get_rank()
else:
return 0
def get_local_rank():
if "LOCAL_RANK" in os.environ:
return int(os.environ["LOCAL_RANK"])
elif dist.is_available() and dist.is_initialized():
return dist.get_local_rank()
else:
return 0
def str2bool(v):
"""Used in argparse.ArgumentParser.add_argument to indicate
that a type is a bool type and user can enter
- yes, true, t, y, 1, to represent True
- no, false, f, n, 0, to represent False
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
class AttributeDict(dict):
def __getattr__(self, key):
if key in self:
return self[key]
raise AttributeError(f"No such attribute '{key}'")
def __setattr__(self, key, value):
self[key] = value
def __delattr__(self, key):
if key in self:
del self[key]
return
raise AttributeError(f"No such attribute '{key}'")
def __str__(self, indent: int = 2):
tmp = {}
for k, v in self.items():
# PosixPath is ont JSON serializable
if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
v = str(v)
tmp[k] = v
return json.dumps(tmp, indent=indent, sort_keys=True)
def setup_logger(
log_filename: Pathlike,
log_level: str = "info",
use_console: bool = True,
) -> None:
"""Setup log level.
Args:
log_filename:
The filename to save the log.
log_level:
The log level to use, e.g., "debug", "info", "warning", "error",
"critical"
use_console:
True to also print logs to console.
"""
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
log_filename = f"{log_filename}-{date_time}-{rank}"
else:
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
log_filename = f"{log_filename}-{date_time}"
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
level = logging.ERROR
if log_level == "debug":
level = logging.DEBUG
elif log_level == "info":
level = logging.INFO
elif log_level == "warning":
level = logging.WARNING
elif log_level == "critical":
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename,
format=formatter,
level=level,
filemode="w",
force=True,
)
if use_console:
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class will play a role as metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
if v - v == 0:
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans_frames = ""
ans_utterances = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
if "utt_" not in k:
ans_frames += str(k) + "=" + str(norm_value) + ", "
else:
ans_utterances += str(k) + "=" + str(norm_value)
if k == "utt_duration":
ans_utterances += " frames, "
elif k == "utt_pad_proportion":
ans_utterances += ", "
else:
raise ValueError(f"Unexpected key: {k}")
frames = "%.2f" % self["frames"]
ans_frames += "over " + str(frames) + " frames. "
if ans_utterances != "":
utterances = "%.2f" % self["utterances"]
ans_utterances += "over " + str(utterances) + " utterances."
return ans_frames + ans_utterances
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self["frames"] if "frames" in self else 1
num_utterances = self["utterances"] if "utterances" in self else 1
ans = []
for k, v in self.items():
if k == "frames" or k == "utterances":
continue
norm_value = (
float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
)
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
If it is a multi-talker ASR system, the ref and hyp may also be lists of
strings.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf8") as f:
for cut_id, ref, hyp in texts:
if char_level:
ref = list("".join(ref))
hyp = list("".join(hyp))
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
compute_CER: bool = False,
sclite_mode: bool = False,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)

View File

@ -0,0 +1,434 @@
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
import io
import sys
from argparse import ArgumentParser
import gradio as gr
import gradio.processing_utils as processing_utils
import numpy as np
import sherpa_onnx
import soundfile as sf
import torch
import whisper
#from cosyvoice.cli.cosyvoice import CosyVoice
from gradio_client import utils as client_utils
from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_model(params, device="cuda"):
"""Load and prepare the speech-to-speech model."""
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
else:
attn_implementation = "eager"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch.float16,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
# codec_vocab_size = 4096 + 4
codec_vocab_size = 6561 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch.float16,
)
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
return model, tokenizer
def audio_decode_cosyvoice(audio_tokens, codec_decoder):
"""
Generate audio from tokens with optional tone and prompt embedding.
Args:
audio_tokens (list): List of audio tokens to be processed.
codec_decoder: Codec decoder for generating audio.
Returns:
torch.Tensor: Generated audio waveform.
"""
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
prompt_speech_feat = torch.zeros(1, 0, 80)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
prompt_token_len=torch.tensor(
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
prompt_feat_len=torch.tensor(
[prompt_speech_feat.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
embedding=flow_embedding.to(codec_decoder.model.device),
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def preprocess(
messages,
tokenizer,
):
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="longest",
truncation=False,
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def format_history(history: list):
messages = []
for item in history:
if isinstance(item["content"], str):
messages.append({"role": item["role"], "content": item["content"]})
return messages
def decode(
model,
token2wav_model,
tokenizer,
feature,
messages,
):
"""Decode one
Returns:
pass
"""
dtype = torch.float32
device = model.llm.device
feature = feature.to(device, dtype=dtype)
input_ids, attention_mask = preprocess([messages], tokenizer)
generated_ids, audio_tokens = model.decode_with_speech_output(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
yield {"type": "text", "data": hyps[0]}
audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
audio = audio_hat.squeeze(0).cpu().numpy()
audio = np.array(audio * 32767).astype(np.int16)
wav_io = io.BytesIO()
sf.write(wav_io, audio, samplerate=22050, format="WAV")
wav_io.seek(0)
wav_bytes = wav_io.getvalue()
audio_path = processing_utils.save_bytes_to_cache(
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
)
yield {"type": "audio", "data": audio_path}
def media_predict(audio, history):
# First yield
yield (
None, # microphone
history, # media_chatbot
gr.update(visible=False), # submit_btn
gr.update(visible=True), # stop_btn
)
print(2333, history, audio)
history.append({"role": "user", "content": (audio,)})
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
history.append({"role": "assistant", "content": ""})
formatted_history = format_history(
history=history
) # only keep string text format
assert audio is not None
audio_transcript = get_transcript(
audio,
asr_model,
)
history[-2]["content"] = audio_transcript
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
fbank = fbank.unsqueeze(0)
assert fbank.ndim == 3
for chunk in decode(
model, token2wav_model, tokenizer, fbank, formatted_history
):
if chunk["type"] == "text":
history[-1]["content"] = chunk["data"]
yield (
None, # microphone
history, # media_chatbot
gr.update(visible=False), # submit_btn
gr.update(visible=True), # stop_btn
)
if chunk["type"] == "audio":
history.append(
{"role": "assistant", "content": gr.Audio(chunk["data"])}
)
# Final yield
yield (
None, # microphone
history, # media_chatbot
gr.update(visible=True), # submit_btn
gr.update(visible=False), # stop_btn
)
with gr.Blocks() as demo:
with gr.Tab("Online"):
with gr.Row():
with gr.Column(scale=1):
microphone = gr.Audio(sources=["microphone"], type="filepath")
submit_btn = gr.Button("Submit", variant="primary")
stop_btn = gr.Button("Stop", visible=False)
clear_btn = gr.Button("Clear History")
with gr.Column(scale=2):
media_chatbot = gr.Chatbot(height=650, type="messages")
def clear_history():
return [], gr.update(value=None)
submit_event = submit_btn.click(
fn=media_predict,
inputs=[
microphone,
media_chatbot,
],
outputs=[microphone, media_chatbot, submit_btn, stop_btn],
)
stop_btn.click(
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[submit_event],
queue=False,
)
clear_btn.click(
fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
)
demo.queue(default_concurrency_limit=100, max_size=100).launch(
max_threads=100,
ssr_mode=False,
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
def _get_args():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=str,
default=None,
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--token2wav-path",
type=str,
default=None,
help="Token2Wav path, default to %(default)r",
)
parser.add_argument(
"--asr-model-dir",
type=str,
default=None,
help="ASR model dir, default to %(default)r",
)
parser.add_argument(
"--flash-attn2",
action="store_true",
default=False,
help="Enable flash_attention_2 when loading the model.",
)
parser.add_argument(
"--share",
action="store_true",
default=False,
help="Create a publicly shareable link for the interface.",
)
parser.add_argument(
"--inbrowser",
action="store_true",
default=False,
help="Automatically launch the interface in a new tab on the default browser.",
)
parser.add_argument(
"--server-port", type=int, default=8001, help="Demo server port."
)
parser.add_argument(
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
)
add_model_arguments(parser)
args = parser.parse_args()
return args
def read_wave(wave_filename: str):
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and can be of type
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples,
which are normalized to the range [-1, 1].
- Sample rate of the wave file.
"""
samples, sample_rate = sf.read(wave_filename, dtype="float32")
assert (
samples.ndim == 1
), f"Expected single channel, but got {samples.ndim} channels."
samples_float32 = samples.astype(np.float32)
return samples_float32, sample_rate
def get_transcript(audio_path, recognizer):
samples, sample_rate = read_wave(audio_path)
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
recognizer.decode_streams([s])
return s.result.text
if __name__ == "__main__":
args = _get_args()
model, tokenizer = get_model(args)
token2wav = CosyVoice(
args.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
tokens=f"{args.asr_model_dir}/tokens.txt",
num_threads=2,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
debug=False,
)
_launch_demo(args, model, tokenizer, token2wav, asr_model)

View File

@ -0,0 +1 @@
../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py