mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 559f9e2deff33077461428d422d9f03c95988b01 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
a5de488304
55
egs/speech_llm/SPEECH2SPEECH/README.md
Normal file
55
egs/speech_llm/SPEECH2SPEECH/README.md
Normal file
@ -0,0 +1,55 @@
|
||||
|
||||
# Introduction
|
||||
|
||||
This recipe includes scripts for training speech2speech models.
|
||||
|
||||
# SPEECH2SPEECH
|
||||
|
||||
The following table lists the folders for different tasks.
|
||||
|
||||
|Recipe | Speech Input | Speech Output | Comment|
|
||||
|--------------|--------------|---------------|--------|
|
||||
|Qwen-omni like| Continuous Embeddins| Cosyvoice1 50Hz Single-codebook Token | Text-driven; using Thinker LLM for text token, small Talker LLM for speech token |
|
||||
|
||||
### [Qwen-omni like Speech2speech Recipe](./qwen_omni)
|
||||
|
||||
[Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni) style model using [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.
|
||||
|
||||
<br>
|
||||
<p align="center">
|
||||
<img src="assets/framework.png" width="800"/>
|
||||
<p>
|
||||
<br>
|
||||
|
||||
Command for training is:
|
||||
```bash
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
```
|
||||
|
||||
Command for decoding is:
|
||||
```bash
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--method e2e-epoch10_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path models/CosyVoice-300M-SFT \
|
||||
--use-lora True
|
||||
```
|
||||
|
||||
Please see [`prepare.sh`](./prepare.sh) for more details.
|
BIN
egs/speech_llm/SPEECH2SPEECH/assets/framework.png
Normal file
BIN
egs/speech_llm/SPEECH2SPEECH/assets/framework.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 101 KiB |
234
egs/speech_llm/SPEECH2SPEECH/exp.sh
Normal file
234
egs/speech_llm/SPEECH2SPEECH/exp.sh
Normal file
@ -0,0 +1,234 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
|
||||
# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
|
||||
set -eou pipefail
|
||||
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
||||
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
|
||||
if [ ! -L "/workspace/slam" ]; then
|
||||
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||
fi
|
||||
log "stage 17: Training Speech2Speech Model, full parameters"
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
|
||||
pretrained_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=4
|
||||
|
||||
latest_checkpoint_step=-1
|
||||
# Check if exp_dir exists and is a directory
|
||||
if [ -d "$exp_dir" ]; then
|
||||
# List directories matching checkpoint-* and find the one with the largest step number
|
||||
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# Extract step number using parameter expansion
|
||||
current_step=${checkpoint_name#checkpoint-}
|
||||
# Ensure current_step is a number
|
||||
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
latest_checkpoint_step=$current_step
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
train_cmd_args="--max-duration 200 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True --on-the-fly-speed-perturb False\
|
||||
--deepspeed \
|
||||
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True --on-the-fly-feats True \
|
||||
--dataset vocalnet_ultrachat_voiceassistant_instruct_s2s --num-epochs 10 \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
step=$latest_checkpoint_step
|
||||
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
else
|
||||
log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# No pretrained model or sampler state dict needed for the first run
|
||||
fi
|
||||
|
||||
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
|
||||
$train_cmd_args
|
||||
fi
|
||||
|
||||
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
|
||||
# check if the link exists, if not exist, create it
|
||||
if [ ! -L "/workspace/slam" ]; then
|
||||
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||
fi
|
||||
log "stage 17: Training Speech2Speech Model, full parameters"
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
|
||||
pretrained_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=4
|
||||
|
||||
latest_checkpoint_step=-1
|
||||
# Check if exp_dir exists and is a directory
|
||||
if [ -d "$exp_dir" ]; then
|
||||
# List directories matching checkpoint-* and find the one with the largest step number
|
||||
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# Extract step number using parameter expansion
|
||||
current_step=${checkpoint_name#checkpoint-}
|
||||
# Ensure current_step is a number
|
||||
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
latest_checkpoint_step=$current_step
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
train_cmd_args="--max-duration 200 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True --on-the-fly-speed-perturb False\
|
||||
--deepspeed \
|
||||
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True --on-the-fly-feats True \
|
||||
--dataset vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech --num-epochs 10 \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
step=$latest_checkpoint_step
|
||||
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
else
|
||||
log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# No pretrained model or sampler state dict needed for the first run
|
||||
fi
|
||||
|
||||
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
|
||||
$train_cmd_args
|
||||
fi
|
||||
|
||||
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
log "stage 19: Training TTS Model"
|
||||
exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
|
||||
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
|
||||
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
|
||||
pretrained_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=4
|
||||
|
||||
latest_checkpoint_step=-1
|
||||
# Check if exp_dir exists and is a directory
|
||||
if [ -d "$exp_dir" ]; then
|
||||
# List directories matching checkpoint-* and find the one with the largest step number
|
||||
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# Extract step number using parameter expansion
|
||||
current_step=${checkpoint_name#checkpoint-}
|
||||
# Ensure current_step is a number
|
||||
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
latest_checkpoint_step=$current_step
|
||||
fi
|
||||
done
|
||||
fi
|
||||
# --dataset ultra_chat_voice_assistant
|
||||
train_cmd_args="--batch-size 30 \
|
||||
--exp-dir $exp_dir \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--enable-speech-input False \
|
||||
--deepspeed \
|
||||
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--num-epochs 3 \
|
||||
--use-lora False --unfreeze-llm False --enable-speech-output True"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
step=$latest_checkpoint_step
|
||||
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
else
|
||||
log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# No pretrained model or sampler state dict needed for the first run
|
||||
fi
|
||||
|
||||
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train_tts.py \
|
||||
$train_cmd_args
|
||||
fi
|
||||
|
||||
|
||||
# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
# log "stage 20: Training TTS Model"
|
||||
# echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
|
||||
# if [ ! -L "/workspace/slam" ]; then
|
||||
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||
# fi
|
||||
# exp_dir=./qwen_omni/exp_test
|
||||
# ngpu=4
|
||||
|
||||
# latest_checkpoint_step=-1
|
||||
# # Check if exp_dir exists and is a directory
|
||||
# if [ -d "$exp_dir" ]; then
|
||||
# # List directories matching checkpoint-* and find the one with the largest step number
|
||||
# for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
# checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# # Extract step number using parameter expansion
|
||||
# current_step=${checkpoint_name#checkpoint-}
|
||||
# # Ensure current_step is a number
|
||||
# if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
# latest_checkpoint_step=$current_step
|
||||
# fi
|
||||
# done
|
||||
# fi
|
||||
|
||||
# train_cmd_args="--max-duration 150 \
|
||||
# --enable-musan False \
|
||||
# --exp-dir $exp_dir \
|
||||
# --speech-encoder-path-or-name models/large-v2.pt \
|
||||
# --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
# --dataset vocalnet_ultrachat_voiceassistant \
|
||||
# --manifest-dir data/fbank \
|
||||
# --deepspeed \
|
||||
# --deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
# --use-flash-attn True --on-the-fly-feats True \
|
||||
# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
|
||||
|
||||
# if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
# log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
# step=$latest_checkpoint_step
|
||||
# train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
# else
|
||||
# log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# # No pretrained model or sampler state dict needed for the first run
|
||||
# fi
|
||||
|
||||
# torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
|
||||
# $train_cmd_args
|
||||
# fi
|
||||
|
||||
|
||||
# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||
# log "stage 21: TTS Decoding Test Set"
|
||||
# exp_dir=./qwen_omni/exp_tts
|
||||
# torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
|
||||
# --exp-dir $exp_dir \
|
||||
# --speech-encoder-path-or-name models/large-v2.pt \
|
||||
# --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
# --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
|
||||
# --use-flash-attn True \
|
||||
# --enable-speech-output True \
|
||||
# --token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
# --use-lora True
|
||||
# fi
|
291
egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
Executable file
291
egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
Executable file
@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
|
||||
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
|
||||
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
|
||||
# Copyright 2025 Nvidia (Yuekai Zhang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
--out-dir data/fbank \
|
||||
--huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
|
||||
--audio-key question_audio --text-key answer \
|
||||
--prefix ultrachat
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
|
||||
from vocalnet_lhotse_cutset import LazyCustomDatasetIterator
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resample-to-16kHz",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Resample audio to 16kHz. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed-perturb",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
type=str,
|
||||
default="data/fbank",
|
||||
help="Output directory for the computed features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huggingface-dataset-path-or-name",
|
||||
type=str,
|
||||
default="/workspace/Belle_1.4M-SLAM-Omni",
|
||||
help="The path or name of the Huggingface dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-key",
|
||||
type=str,
|
||||
default="question_audio",
|
||||
help="The key in the Huggingface dataset containing the audio data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-key",
|
||||
type=str,
|
||||
default="answer",
|
||||
help="The key in the Huggingface dataset containing the text data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="belle",
|
||||
help="""The dataset prefix to use when saving the features""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-file-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the json file containing the vocalnet data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--drop-recordings",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Drop recordings. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The subset to use from the Huggingface dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=str,
|
||||
default="train",
|
||||
help="The split to use from the Huggingface dataset",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def remove_short_and_long_utt(c):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# ../local/display_manifest_statistics.py
|
||||
#
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 1.0 or c.duration > 50.0:
|
||||
# logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def compute_fbank(args):
|
||||
in_out_dir = Path(args.out_dir)
|
||||
in_out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# number of workers in dataloader
|
||||
num_workers = 4
|
||||
|
||||
# number of seconds in a batch
|
||||
batch_duration = 10
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Only WhisperFbank is implemented.")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
dataset = load_dataset(
|
||||
args.huggingface_dataset_path_or_name,
|
||||
args.subset,
|
||||
streaming=True,
|
||||
split=args.split,
|
||||
)
|
||||
num_shards = dataset.num_shards
|
||||
num_digits = 5
|
||||
for i in range(252, num_shards):
|
||||
shard = dataset.shard(num_shards, i)
|
||||
# shard = shard.take(10) # for testing
|
||||
logging.info(
|
||||
f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
|
||||
)
|
||||
|
||||
idx = f"{i}".zfill(num_digits)
|
||||
|
||||
cut_set = CutSet.from_huggingface_dataset(
|
||||
shard, audio_key=args.audio_key, text_key=args.text_key
|
||||
)
|
||||
|
||||
cut_set = cut_set.filter(remove_short_and_long_utt)
|
||||
if args.resample_to_16kHz:
|
||||
cut_set = cut_set.resample(16000)
|
||||
if args.speed_perturb:
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{in_out_dir}/feats_{idx}_{args.subset}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
# cut_set = cut_set.trim_to_supervisions(
|
||||
# keep_overlapping=False, min_duration=None
|
||||
# )
|
||||
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.{args.subset}.jsonl.gz"
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
# see https://github.com/lhotse-speech/lhotse/issues/1125
|
||||
if args.drop_recordings:
|
||||
cut_set.drop_recordings().to_file(cuts_path)
|
||||
else:
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
def compute_fbank_vocalnet(args):
|
||||
in_out_dir = Path(args.out_dir)
|
||||
in_out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# number of workers in dataloader
|
||||
num_workers = 4
|
||||
|
||||
# number of seconds in a batch
|
||||
batch_duration = 10
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Only WhisperFbank is implemented.")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
num_shards = 50
|
||||
num_digits = 5
|
||||
for i in range(num_shards):
|
||||
logging.info(f"Processing shard {i}")
|
||||
idx = f"{i}".zfill(num_digits)
|
||||
cut_set = CutSet(
|
||||
LazyCustomDatasetIterator(
|
||||
json_file_path=args.json_file_path, shard_id=i, num_shards=num_shards
|
||||
)
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
if args.resample_to_16kHz:
|
||||
cut_set = cut_set.resample(16000)
|
||||
if args.speed_perturb:
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{in_out_dir}/feats_{idx}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
# see https://github.com/lhotse-speech/lhotse/issues/1125
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
if args.json_file_path is not None:
|
||||
compute_fbank_vocalnet(args)
|
||||
else:
|
||||
compute_fbank(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
99
egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py
Normal file
99
egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py
Normal file
@ -0,0 +1,99 @@
|
||||
# https://huggingface.co/datasets/VocalNet/UltraChat-vocalnet/blob/main/UltraChat.json
|
||||
# https://huggingface.co/datasets/VocalNet/VoiceAssistant-430K-vocalnet/blob/main/VoiceAssistant-430K.json
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from lhotse import CutSet
|
||||
from lhotse.audio import Recording
|
||||
from lhotse.supervision import SupervisionSegment
|
||||
|
||||
|
||||
class LazyCustomDatasetIterator:
|
||||
"""
|
||||
Thin wrapper on top of HF datasets objects that allows to interact with them through a Lhotse CutSet.
|
||||
It can be initialized with an existing HF dataset, or args/kwargs passed on to ``datasets.load_dataset()``.
|
||||
Use ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` options to indicate which keys in dict examples
|
||||
returned from HF Dataset should be looked up for audio, transcript, language, and gender respectively.
|
||||
The remaining keys in HF dataset examples will be stored inside ``cut.custom`` dictionary.
|
||||
Example with existing HF dataset::
|
||||
>>> import datasets
|
||||
... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
|
||||
... dataset = dataset.map(some_transform)
|
||||
... cuts_it = LazyHFDatasetIterator(dataset)
|
||||
... for cut in cuts_it:
|
||||
... pass
|
||||
Example providing HF dataset init args/kwargs::
|
||||
>>> import datasets
|
||||
... cuts_it = LazyHFDatasetIterator("mozilla-foundation/common_voice_11_0", "hi", split="test")
|
||||
... for cut in cuts_it:
|
||||
... pass
|
||||
"""
|
||||
|
||||
def __init__(self, json_file_path: str, shard_id: int = 0, num_shards: int = 100):
|
||||
self.json_file_path = json_file_path
|
||||
self.shard_id = shard_id
|
||||
self.num_shards = num_shards
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
with open(self.json_file_path, "r", encoding="utf-8") as f:
|
||||
list_data_dict = json.load(f)
|
||||
list_data_dict = list_data_dict[self.shard_id :: self.num_shards]
|
||||
for item in list_data_dict:
|
||||
custom_data = item.copy()
|
||||
json_file_parent_of_parent_dir = os.path.dirname(
|
||||
os.path.dirname(self.json_file_path)
|
||||
)
|
||||
units_path = os.path.join(
|
||||
json_file_parent_of_parent_dir, custom_data["units"]
|
||||
)
|
||||
speech_token_dict = np.load(units_path, allow_pickle=True).item()
|
||||
speech_token = speech_token_dict["speech_token"].squeeze(0).tolist()
|
||||
speech_token_len = speech_token_dict["speech_token_len"]
|
||||
|
||||
assert len(speech_token) == speech_token_len
|
||||
custom_data["speech_token"] = speech_token
|
||||
audio_path = custom_data.pop("speech", None)
|
||||
audio_path = os.path.join(json_file_parent_of_parent_dir, audio_path)
|
||||
item_id = item.get("id")
|
||||
recording = Recording.from_file(path=audio_path, recording_id=item_id)
|
||||
|
||||
conversations = item.get("conversations")
|
||||
assert isinstance(conversations, list) and len(conversations) == 2
|
||||
for conv in conversations:
|
||||
if isinstance(conv, dict) and conv.get("from") == "gpt":
|
||||
gpt_text = conv.get("value")
|
||||
break
|
||||
assert gpt_text is not None
|
||||
|
||||
supervision = SupervisionSegment(
|
||||
id=item_id,
|
||||
recording_id=recording.id,
|
||||
start=0.0, # Assuming the supervision covers the entire recording
|
||||
duration=recording.duration,
|
||||
text=gpt_text,
|
||||
)
|
||||
|
||||
cut = recording.to_cut()
|
||||
# cut.id will be the same as recording.id
|
||||
|
||||
cut.supervisions = [supervision]
|
||||
# custom_data contains the original item's fields, minus "speech".
|
||||
# So, "id", "conversations", "units", etc., are preserved here.
|
||||
custom_data.pop("conversations")
|
||||
custom_data.pop("units")
|
||||
cut.custom = custom_data
|
||||
|
||||
yield cut
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
json_file_path = (
|
||||
"/workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json"
|
||||
)
|
||||
cut_set = CutSet(LazyCustomDatasetIterator(json_file_path=json_file_path))
|
||||
|
||||
for cut in cut_set:
|
||||
print(cut)
|
||||
input()
|
444
egs/speech_llm/SPEECH2SPEECH/prepare.sh
Normal file
444
egs/speech_llm/SPEECH2SPEECH/prepare.sh
Normal file
@ -0,0 +1,444 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: Clone CosyVoice repo and install requirements inside the container"
|
||||
# docker: ghcr.io/swivid/f5-tts:main
|
||||
pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
|
||||
cd /workspace/CosyVoice
|
||||
# If you failed to clone submodule due to network failures, please run following command until success
|
||||
git submodule update --init --recursive
|
||||
pip install -r qwen_omni/requirements.txt
|
||||
pip install -r qwen_omni/requirements-cosyvoice.txt
|
||||
|
||||
# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Cosyvoice pretrained model for speech token2wav module
|
||||
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
|
||||
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
|
||||
|
||||
# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
|
||||
pip install sherpa-onnx
|
||||
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
|
||||
if [ ! -d $model_path ]; then
|
||||
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
|
||||
fi
|
||||
fi
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "stage 1: Compute fbank feature from huggingface"
|
||||
python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
--out-dir data/fbank_test \
|
||||
--huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
|
||||
--audio-key question_audio --text-key answer \
|
||||
--prefix belle
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Combine features"
|
||||
manifest_dir=data/fbank
|
||||
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
|
||||
mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
|
||||
# exclude cust_belle_00000.jsonl.gz for valid and test set
|
||||
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
|
||||
echo $pieces | wc
|
||||
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
|
||||
mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
|
||||
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
|
||||
ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
|
||||
fi
|
||||
fi
|
||||
|
||||
ngpu=8
|
||||
exp_dir=./qwen_omni/exp_speech2speech
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "stage 3: Training Speech2Speech Model"
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "stage 4: Decoding, only support batch_size=1 for now."
|
||||
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--method e2e-epoch10_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path models/CosyVoice-300M-SFT \
|
||||
--use-lora True
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "stage 5: Gradio Demo"
|
||||
python3 ./qwen_omni/web_demo.py \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--checkpoint-path $exp_dir/epoch-999.pt \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
|
||||
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "stage 6: Compute fbank feature from huggingface"
|
||||
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
|
||||
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
# --out-dir data/fbank_voice_assistant \
|
||||
# --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \
|
||||
# --audio-key question_audio --text-key answer \
|
||||
# --prefix voice_assistant
|
||||
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
--out-dir data/fbank_voice_assistant_cosy2 \
|
||||
--json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
|
||||
--prefix voice_assistant
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "stage 7: Compute fbank feature from huggingface"
|
||||
# CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
|
||||
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
# --out-dir data/fbank_ultrachat \
|
||||
# --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
|
||||
# --audio-key question_audio --text-key answer \
|
||||
# --prefix ultrachat
|
||||
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
--out-dir data/fbank_ultrachat_cosy2 \
|
||||
--json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \
|
||||
--prefix ultrachat
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "stage 8: Compute fbank feature from huggingface"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
--out-dir data/fbank_gigaspeech \
|
||||
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
|
||||
--subset test --split test \
|
||||
--audio-key audio --text-key text \
|
||||
--prefix gigaspeech
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
|
||||
--out-dir data/fbank_gigaspeech \
|
||||
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
|
||||
--subset xl --split train \
|
||||
--audio-key audio --text-key text \
|
||||
--prefix gigaspeech
|
||||
fi
|
||||
|
||||
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||
ngpu=4
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "stage 10: Training Speech2Speech Model"
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 150 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset-format vocalnet \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True --on-the-fly-feats True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "stage 11: Decoding EN, val set only support batch_size=1 for now."
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en_continue
|
||||
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--epoch 997 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--method e2e-epoch4_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
--use-lora True
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "stage 12: Decoding EN voicebench"
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en_continue
|
||||
torchrun --nproc_per_node=2 \
|
||||
./qwen_omni/decode_dist.py \
|
||||
--output-dir $exp_dir/log_voicebench \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
|
||||
--use-lora True --subset-name openbookqa --split-name test
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
|
||||
log "stage 13: Server"
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en_continue
|
||||
python3 ./qwen_omni/server.py \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--use-lora True
|
||||
fi
|
||||
|
||||
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
|
||||
log "stage 14: Client"
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
|
||||
# exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
|
||||
# The final assignment of datasets in the original script is used here:
|
||||
# (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
|
||||
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
|
||||
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh")
|
||||
declare -a target_datasets=("mmsu")
|
||||
|
||||
NUM_CLIENT_JOBS=4 # Number of parallel client jobs
|
||||
BASE_PORT=8000 # Base port for servers
|
||||
|
||||
log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."
|
||||
|
||||
for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
|
||||
do
|
||||
( # Start a subshell for backgrounding this client job's tasks
|
||||
current_port=$(expr $BASE_PORT + $job_id)
|
||||
log "Client Job $job_id: Initializing. Will connect to port $current_port."
|
||||
|
||||
processed_count_for_this_job=0
|
||||
# Iterate over all datasets using their indices
|
||||
for i in "${!target_datasets[@]}"; do
|
||||
# Assign dataset to job_id in a round-robin fashion
|
||||
if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
|
||||
dataset="${target_datasets[$i]}"
|
||||
|
||||
# local split_name # Determine split_name based on dataset
|
||||
if [ "$dataset" == "sd-qa" ]; then
|
||||
split_name="usa"
|
||||
else
|
||||
split_name="test"
|
||||
fi
|
||||
|
||||
log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
|
||||
python3 ./qwen_omni/client.py \
|
||||
--subset-name "$dataset" \
|
||||
--split-name "$split_name" \
|
||||
--output-dir "$exp_dir/results" \
|
||||
--port "$current_port" # Assuming client.py accepts --port
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
|
||||
fi
|
||||
processed_count_for_this_job=$(($processed_count_for_this_job + 1))
|
||||
fi
|
||||
done
|
||||
log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
|
||||
) & # Run this client job's subshell in the background
|
||||
done
|
||||
|
||||
log "All client jobs launched. Waiting for completion..."
|
||||
wait # Wait for all backgrounded client jobs to complete
|
||||
log "All client jobs have completed."
|
||||
fi
|
||||
|
||||
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||
log "stage 15: Training Speech2Speech Model, adaptor only"
|
||||
exp_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=2
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 700 \
|
||||
--enable-musan False \
|
||||
--audio-key audio --text-key continuation \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--dataset-format speech_continuation \
|
||||
--start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \
|
||||
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
|
||||
fi
|
||||
|
||||
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
|
||||
log "stage 16: Training Speech2Speech Model, adaptor only"
|
||||
exp_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=4
|
||||
|
||||
latest_checkpoint_step=-1
|
||||
# Check if exp_dir exists and is a directory
|
||||
if [ -d "$exp_dir" ]; then
|
||||
# List directories matching checkpoint-* and find the one with the largest step number
|
||||
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# Extract step number using parameter expansion
|
||||
current_step=${checkpoint_name#checkpoint-}
|
||||
# Ensure current_step is a number
|
||||
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
latest_checkpoint_step=$current_step
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
train_cmd_args="--max-duration 800 \
|
||||
--enable-musan False \
|
||||
--audio-key audio --text-key continuation \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True \
|
||||
--deepspeed \
|
||||
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--dataset-format speech_continuation \
|
||||
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
step=$latest_checkpoint_step
|
||||
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
else
|
||||
log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# No pretrained model or sampler state dict needed for the first run
|
||||
fi
|
||||
|
||||
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
|
||||
$train_cmd_args
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
||||
# pip install gradio sherpa-onnx
|
||||
log "stage 17: Server for adapter only speech continuation"
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s_librispeech
|
||||
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_three_s2s
|
||||
|
||||
N_GPUS=4 # Define the number of GPUs/processes you want to launch
|
||||
|
||||
for id in $(seq 0 $(($N_GPUS - 1)))
|
||||
do
|
||||
log "Launching server on GPU $id with port $(expr 8000 + $id)"
|
||||
CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--checkpoint-path $exp_dir/checkpoint-55276/pytorch_model.bin \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output False \
|
||||
--port $(expr 18000 + $id) \
|
||||
--use-lora True &
|
||||
done
|
||||
|
||||
wait # Wait for all background processes to complete
|
||||
fi
|
||||
|
||||
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||
log "stage 18: Training kl-div Speech2Speech Model, adaptor only"
|
||||
exp_dir=./qwen_omni/exp_speech2text_kl
|
||||
ngpu=2
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 700 \
|
||||
--enable-musan False \
|
||||
--audio-key audio --text-key continuation \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--dataset-format speech_continuation \
|
||||
--loss-type kl_div --dataset librispeech \
|
||||
--pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \
|
||||
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
|
||||
fi
|
||||
|
||||
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
log "stage 19: Server for kl loss"
|
||||
exp_dir=./qwen_omni/exp_speech2text_kl
|
||||
python3 ./qwen_omni/server.py \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output False \
|
||||
--use-lora False --prompt-template qa
|
||||
fi
|
||||
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage"
|
||||
exp_dir=./qwen_omni/exp_speech2text_kl_llm
|
||||
pretrained_dir=./qwen_omni/exp_speech2text_kl
|
||||
ngpu=2
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 200 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet
|
||||
fi
|
156
egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
Normal file
156
egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
Normal file
@ -0,0 +1,156 @@
|
||||
# client.py
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Speech-to-Text Client")
|
||||
parser.add_argument(
|
||||
"--server-url",
|
||||
type=str,
|
||||
default="http://localhost",
|
||||
help="URL of the FastAPI server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port of the FastAPI server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="hlt-lab/voicebench",
|
||||
help="Hugging Face dataset name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subset-name",
|
||||
type=str,
|
||||
default="commoneval", # Adjust as needed
|
||||
help="Dataset subset name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default=None, # Adjust as needed
|
||||
help="Dataset split name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="Directory to save results"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
output_filename = os.path.join(
|
||||
args.output_dir,
|
||||
f"{args.subset_name}-{args.split_name}.jsonl",
|
||||
)
|
||||
server_decode_url = f"{args.server_url}:{args.port}/decode"
|
||||
|
||||
print("Loading dataset...")
|
||||
if args.subset_name != "mmsu":
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.subset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
# load all splits and concatenate them
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.subset_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
dataset = concatenate_datasets([dataset[subset] for subset in dataset])
|
||||
|
||||
print(f"Dataset loaded with {len(dataset)} samples.")
|
||||
print(f"Sending requests to {server_decode_url}...")
|
||||
print(f"Saving results to {output_filename}")
|
||||
|
||||
with open(output_filename, "w", encoding="utf-8") as outfile:
|
||||
# Iterate directly over the dataset
|
||||
progress_bar = tqdm(dataset, desc="Processing", unit="samples")
|
||||
for item in progress_bar:
|
||||
|
||||
audio_info = item.get("audio")
|
||||
assert (
|
||||
audio_info["sampling_rate"] == 16000
|
||||
), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
|
||||
|
||||
# Prepare data for JSON serialization and server request
|
||||
audio_array = audio_info["array"].tolist() # Convert numpy array to list
|
||||
result_dict = {}
|
||||
for key in item.keys():
|
||||
if key != "audio":
|
||||
# Ensure other fields are JSON serializable
|
||||
try:
|
||||
# Attempt to serialize to catch issues early (optional)
|
||||
json.dumps(item[key])
|
||||
result_dict[key] = item[key]
|
||||
except (TypeError, OverflowError):
|
||||
print(
|
||||
f"Warning: Converting non-serializable key '{key}' to string."
|
||||
)
|
||||
result_dict[key] = str(
|
||||
item[key]
|
||||
) # Convert problematic types to string
|
||||
|
||||
payload = {
|
||||
"audio": audio_array,
|
||||
"sampling_rate": 16000,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(server_decode_url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
server_response = response.json()
|
||||
decoded_text = server_response.get("text", "")
|
||||
|
||||
# Add the response to the result dictionary
|
||||
result_dict["response"] = decoded_text
|
||||
print(result_dict)
|
||||
# Write result to JSONL file
|
||||
json.dump(result_dict, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"\nError sending request for an item: {e}")
|
||||
error_entry = result_dict # Use the data prepared so far
|
||||
error_entry["error"] = str(e)
|
||||
error_entry["response"] = ""
|
||||
json.dump(error_entry, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
except json.JSONDecodeError:
|
||||
print("\nError decoding server response for an item.")
|
||||
error_entry = result_dict
|
||||
error_entry["error"] = "Invalid JSON response from server"
|
||||
error_entry["response"] = ""
|
||||
json.dump(error_entry, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error processing an item: {e}")
|
||||
error_entry = result_dict
|
||||
error_entry["error"] = f"Unexpected error: {str(e)}"
|
||||
error_entry["response"] = ""
|
||||
json.dump(error_entry, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
|
||||
# Progress bar updates automatically by iterating over tqdm(dataset)
|
||||
|
||||
# No need to close progress_bar explicitly when iterating directly
|
||||
|
||||
print("Processing finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
813
egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
Normal file
813
egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
Normal file
@ -0,0 +1,813 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
load_manifest,
|
||||
load_manifest_lazy,
|
||||
)
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PerturbSpeed,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from utils import get_local_rank, str2bool
|
||||
import io
|
||||
import wave
|
||||
import random
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=300.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-speed-perturb",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use on-the-fly speed perturbation. "
|
||||
"Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--huggingface-dataset-path-or-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path or name of the Huggingface dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--audio-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The key in the Huggingface dataset containing the audio data",
|
||||
)
|
||||
group.add_argument(
|
||||
"--text-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The key in the Huggingface dataset containing the text data",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
|
||||
transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
rank = get_local_rank()
|
||||
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 1000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=True if self.args.num_workers > 0 else False,
|
||||
pin_memory=True,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_valid:
|
||||
CutSet for validation.
|
||||
"""
|
||||
logging.info("About to create dev dataset")
|
||||
rank = get_local_rank()
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
if self.args.bucketing_sampler:
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
else:
|
||||
valid_sampler = SimpleCutSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_num_workers = 1
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=valid_num_workers,
|
||||
persistent_workers=True if valid_num_workers > 0 else False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts_belle(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return {
|
||||
"test": load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
|
||||
)
|
||||
}
|
||||
@lru_cache()
|
||||
def dev_cuts_belle(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
|
||||
)
|
||||
@lru_cache()
|
||||
def train_cuts_belle(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
slam_omni_zh_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_belle_train.jsonl.gz"
|
||||
)
|
||||
return slam_omni_zh_cuts
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_en_vocalnet(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
VoiceAssistant_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
|
||||
)
|
||||
ultrachat_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
|
||||
)
|
||||
VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
|
||||
ultrachat_cuts = ultrachat_cuts.resample(16000)
|
||||
return CutSet.mux(
|
||||
VoiceAssistant_cuts,
|
||||
ultrachat_cuts,
|
||||
weights=[
|
||||
len(VoiceAssistant_cuts),
|
||||
len(ultrachat_cuts),
|
||||
],
|
||||
)
|
||||
@lru_cache()
|
||||
def valid_cuts_en_vocalnet(self) -> CutSet:
|
||||
logging.info("About to get valid cuts")
|
||||
VoiceAssistant_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
|
||||
)
|
||||
VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
|
||||
return VoiceAssistant_cuts
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts_en_vocalnet(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
VoiceAssistant_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
|
||||
)
|
||||
VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
|
||||
return {"test": VoiceAssistant_cuts}
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_ultravox(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
librispeech_path = (
|
||||
self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
|
||||
)
|
||||
people_speech_path = (
|
||||
self.args.huggingface_dataset_path_or_name + "/peoples_speech"
|
||||
)
|
||||
gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
|
||||
else:
|
||||
librispeech_path = "fixie-ai/librispeech_asr"
|
||||
people_speech_path = "fixie-ai/peoples_speech"
|
||||
gigaspeech_path = "fixie-ai/gigaspeech"
|
||||
# 148_688
|
||||
librispeech_other = load_dataset(
|
||||
librispeech_path, "other", split="train.500", streaming=True
|
||||
)
|
||||
# 104_014
|
||||
librispeech_clean_360 = load_dataset(
|
||||
librispeech_path, "clean", split="train.360", streaming=True
|
||||
)
|
||||
# 28_539
|
||||
librispeech_clean_100 = load_dataset(
|
||||
librispeech_path, "clean", split="train.100", streaming=True
|
||||
)
|
||||
|
||||
# 1_501_271
|
||||
people_speech_clean = load_dataset(
|
||||
people_speech_path, "clean", split="train", streaming=True
|
||||
)
|
||||
# 548_000
|
||||
people_speech_dirty_sa = load_dataset(
|
||||
people_speech_path, "dirty_sa", split="train", streaming=True
|
||||
)
|
||||
|
||||
# 8_266_422
|
||||
|
||||
gigaspeech = load_dataset(
|
||||
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
|
||||
)
|
||||
|
||||
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_100,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
librispeech_other_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_other,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_360,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
gigaspeech_cuts = CutSet.from_huggingface_dataset(
|
||||
gigaspeech, audio_key="audio", text_key="text"
|
||||
)
|
||||
|
||||
people_speech_clean_cuts = CutSet.from_huggingface_dataset(
|
||||
people_speech_clean,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
|
||||
people_speech_dirty_sa,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
return CutSet.mux(
|
||||
librispeech_clean_100_cuts,
|
||||
librispeech_clean_360_cuts,
|
||||
librispeech_other_cuts,
|
||||
gigaspeech_cuts,
|
||||
people_speech_clean_cuts,
|
||||
people_speech_dirty_sa_cuts,
|
||||
weights=[
|
||||
28539,
|
||||
104014,
|
||||
148688,
|
||||
8266422,
|
||||
1501271,
|
||||
548000,
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts_ultravox(self) -> CutSet:
|
||||
logging.info("About to get valid cuts")
|
||||
librispeech_path = "fixie-ai/librispeech_asr"
|
||||
librispeech_clean_valid = load_dataset(
|
||||
librispeech_path, "clean", split="validation", streaming=True
|
||||
)
|
||||
librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_valid,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
return librispeech_clean_valid_cuts
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_librispeech(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
|
||||
else:
|
||||
librispeech_path = "fixie-ai/librispeech_asr"
|
||||
# 148_688
|
||||
librispeech_other = load_dataset(
|
||||
librispeech_path, "other", split="train.500", streaming=True
|
||||
)
|
||||
# 104_014
|
||||
librispeech_clean_360 = load_dataset(
|
||||
librispeech_path, "clean", split="train.360", streaming=True
|
||||
)
|
||||
# 28_539
|
||||
librispeech_clean_100 = load_dataset(
|
||||
librispeech_path, "clean", split="train.100", streaming=True
|
||||
)
|
||||
|
||||
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_100,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
librispeech_other_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_other,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_360,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
return CutSet.mux(
|
||||
librispeech_clean_100_cuts,
|
||||
librispeech_clean_360_cuts,
|
||||
librispeech_other_cuts,
|
||||
weights=[
|
||||
28539,
|
||||
104014,
|
||||
148688,
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_gigaspeech(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
gigaspeech_path = "fixie-ai/gigaspeech"
|
||||
gigaspeech = load_dataset(
|
||||
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
|
||||
)
|
||||
|
||||
gigaspeech_cuts = CutSet.from_huggingface_dataset(
|
||||
gigaspeech, audio_key="audio", text_key="text"
|
||||
)
|
||||
|
||||
return gigaspeech_cuts
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_instruct_s2s(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
|
||||
else:
|
||||
data_path = "yuekai/InstructS2S-200K"
|
||||
# 148_688
|
||||
instruct_s2s_train = load_dataset(
|
||||
data_path, split="train", streaming=True
|
||||
)
|
||||
|
||||
instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
|
||||
instruct_s2s_train,
|
||||
audio_key="question_audio",
|
||||
text_key="answer",
|
||||
)
|
||||
|
||||
instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
|
||||
|
||||
return instruct_s2s_train_cuts
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_en_speech2speech(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
VoiceAssistant_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
|
||||
)
|
||||
ultrachat_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
|
||||
)
|
||||
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
|
||||
else:
|
||||
data_path = "yuekai/InstructS2S-200K"
|
||||
# 148_688
|
||||
instruct_s2s_train = load_dataset(
|
||||
data_path, split="train", streaming=True
|
||||
)
|
||||
|
||||
instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
|
||||
instruct_s2s_train,
|
||||
audio_key="question_audio",
|
||||
text_key="answer",
|
||||
)
|
||||
|
||||
instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
|
||||
|
||||
|
||||
return CutSet.mux(
|
||||
VoiceAssistant_cuts,
|
||||
ultrachat_cuts,
|
||||
instruct_s2s_train_cuts,
|
||||
weights=[
|
||||
len(VoiceAssistant_cuts),
|
||||
len(ultrachat_cuts),
|
||||
423_000,
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_en_speech2speech_librispeech(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
VoiceAssistant_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
|
||||
)
|
||||
ultrachat_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
|
||||
)
|
||||
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
|
||||
else:
|
||||
data_path = "yuekai/InstructS2S-200K"
|
||||
# 148_688
|
||||
instruct_s2s_train = load_dataset(
|
||||
data_path, split="train", streaming=True
|
||||
)
|
||||
|
||||
instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
|
||||
instruct_s2s_train,
|
||||
audio_key="question_audio",
|
||||
text_key="answer",
|
||||
)
|
||||
|
||||
instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
|
||||
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
|
||||
else:
|
||||
librispeech_path = "fixie-ai/librispeech_asr"
|
||||
# 148_688
|
||||
librispeech_other = load_dataset(
|
||||
librispeech_path, "other", split="train.500", streaming=True
|
||||
)
|
||||
# 104_014
|
||||
librispeech_clean_360 = load_dataset(
|
||||
librispeech_path, "clean", split="train.360", streaming=True
|
||||
)
|
||||
# 28_539
|
||||
librispeech_clean_100 = load_dataset(
|
||||
librispeech_path, "clean", split="train.100", streaming=True
|
||||
)
|
||||
|
||||
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_100,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
librispeech_other_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_other,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_360,
|
||||
audio_key="audio",
|
||||
text_key="text",
|
||||
)
|
||||
|
||||
|
||||
return CutSet.mux(
|
||||
librispeech_other_cuts,
|
||||
VoiceAssistant_cuts,
|
||||
ultrachat_cuts,
|
||||
librispeech_clean_360_cuts,
|
||||
instruct_s2s_train_cuts,
|
||||
librispeech_clean_100_cuts,
|
||||
weights=[
|
||||
148688,
|
||||
len(VoiceAssistant_cuts),
|
||||
len(ultrachat_cuts),
|
||||
104014,
|
||||
423_000,
|
||||
28539,
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_emilia_en(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
|
||||
# if self.args.huggingface_dataset_path_or_name is not None:
|
||||
# data_path = self.args.huggingface_dataset_path_or_name + "/emilia_en"
|
||||
# else:
|
||||
# data_path = "yuekai/emilia_en"
|
||||
|
||||
emilia_en_data = load_dataset(
|
||||
data_path, split="train", streaming=True
|
||||
)
|
||||
|
||||
def update_wav_path(example):
|
||||
sampling_rate = 16000 # From current_features
|
||||
duration = 1 # seconds, arbitrary duration for random audio
|
||||
num_channels = 1 # mono
|
||||
sample_width = 2 # 2 bytes = 16-bit audio
|
||||
|
||||
num_frames = int(duration * sampling_rate)
|
||||
|
||||
# Generate random bytes for the PCM data part
|
||||
# This will be random noise, but structurally valid for a WAV file
|
||||
pcm_data = bytes([random.randint(0, 255) for _ in range(num_frames * num_channels * sample_width)])
|
||||
|
||||
# Create a WAV file in memory
|
||||
audio_buffer = io.BytesIO()
|
||||
with wave.open(audio_buffer, 'wb') as wf:
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setsampwidth(sample_width)
|
||||
wf.setframerate(sampling_rate)
|
||||
wf.writeframes(pcm_data) # writeframes expects bytes
|
||||
|
||||
example["wav"] = audio_buffer.getvalue()
|
||||
return example
|
||||
|
||||
emilia_en_data = emilia_en_data.map(update_wav_path)
|
||||
current_features = Features({
|
||||
'id': Value('string'),
|
||||
'text': Value('string'),
|
||||
'duration': Value('float'),
|
||||
'language': Value('string'),
|
||||
'dnsmos': Value('float'),
|
||||
'speech_token': Sequence(Value('int32')),
|
||||
'wav': Audio(sampling_rate=16000)
|
||||
|
||||
})
|
||||
emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
|
||||
emilia_en_data = emilia_en_data.cast(current_features)
|
||||
|
||||
emilia_en_train_cuts = CutSet.from_huggingface_dataset(
|
||||
emilia_en_data, # Adjusted from instruct_s2s_train
|
||||
audio_key="wav",
|
||||
text_key="text",
|
||||
)
|
||||
return emilia_en_train_cuts
|
759
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
Executable file
759
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
Executable file
@ -0,0 +1,759 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||
# Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
# 2024 Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
# Command for decoding using fine-tuned models:
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Cosyvoice pretrained model for speech token2wav module
|
||||
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
|
||||
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
|
||||
|
||||
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--method e2e-epoch10_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path models/CosyVoice-300M-SFT \
|
||||
--use-lora True
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
import whisper
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from data_module import AsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from model import SPEECH_LLM, EncoderProjector
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
||||
from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
|
||||
def audio_decode_cosyvoice2(
|
||||
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||
):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
|
||||
Args:
|
||||
audio_tokens (list): List of audio tokens to be processed.
|
||||
model_config: Configuration object containing vocab settings.
|
||||
codec_decoder: Codec decoder for generating audio.
|
||||
tone_dir (str): The tone directory or setting.
|
||||
audio_prompt_path (str, optional): Path to the audio prompt file. Required when tone_dir is not "default_tone".
|
||||
code_layer (int, optional): Number of code layers. Defaults to 1.
|
||||
num_latency_tokens (int, optional): Number of latency tokens to ignore. Defaults to 0.
|
||||
speed (float, optional): Speed factor for audio generation. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Generated audio waveform.
|
||||
"""
|
||||
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||
"empty", prompt_text, prompt_speech_16k, 24000
|
||||
)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token_len=torch.tensor(
|
||||
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def audio_decode_cosyvoice(audio_tokens, codec_decoder):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
|
||||
Args:
|
||||
audio_tokens (list): List of audio tokens to be processed.
|
||||
codec_decoder: Codec decoder for generating audio.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Generated audio waveform.
|
||||
"""
|
||||
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
|
||||
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
|
||||
prompt_speech_feat = torch.zeros(1, 0, 80)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
|
||||
prompt_token_len=torch.tensor(
|
||||
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
|
||||
prompt_feat_len=torch.tensor(
|
||||
[prompt_speech_feat.shape[1]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
embedding=flow_embedding.to(codec_decoder.model.device),
|
||||
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def get_model(params, device):
|
||||
"""Load and prepare the speech-to-speech model."""
|
||||
if params.remove_whisper_encoder_input_length_restriction:
|
||||
replace_whisper_encoder_forward()
|
||||
|
||||
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
||||
speech_encoder = whisper_model.encoder
|
||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
# torch_dtype=torch.bfloat16 FIX ME
|
||||
torch_dtype = torch.float16
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
else:
|
||||
attn_implementation = "eager"
|
||||
torch_dtype = torch.float16
|
||||
tokenizer.padding_side = "right"
|
||||
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
params.llm_path_or_name,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
if params.use_lora:
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=16,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"up_proj",
|
||||
"gate_proj",
|
||||
"down_proj",
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
llm = get_peft_model(llm, lora_config)
|
||||
llm.print_trainable_parameters()
|
||||
|
||||
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
||||
tokenizer.add_special_tokens(special_tokens_dict)
|
||||
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
||||
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
|
||||
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
|
||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
||||
DEFAULT_SPEECH_TOKEN
|
||||
)
|
||||
|
||||
encoder_projector = EncoderProjector(
|
||||
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
|
||||
)
|
||||
|
||||
if params.enable_speech_output:
|
||||
# Determine attn_implementation and torch_dtype based on use_flash_attn
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
|
||||
else:
|
||||
attn_implementation = "eager"
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# TODO: FIX ME
|
||||
# codec_vocab_size = 4096 + 4
|
||||
codec_vocab_size = 6561 + 4
|
||||
config = Qwen2Config(
|
||||
vocab_size=codec_vocab_size,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
intermediate_size=2048,
|
||||
max_position_embeddings=4096,
|
||||
)
|
||||
|
||||
codec_lm = AutoModelForCausalLM.from_config(
|
||||
config=config,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||
codec_lm.vocab_size = codec_vocab_size
|
||||
codec_lm.config.pad_token_id = codec_vocab_size - 1
|
||||
codec_lm.config.eos_token_id = codec_vocab_size - 2
|
||||
codec_lm.config.bos_token_id = codec_vocab_size - 3
|
||||
codec_lm.config.mask_token_id = codec_vocab_size - 4
|
||||
else:
|
||||
codec_lm = None
|
||||
|
||||
model = SPEECH_LLM(
|
||||
speech_encoder,
|
||||
llm,
|
||||
encoder_projector,
|
||||
codec_lm,
|
||||
codec_lm_padding_side="left" if params.use_flash_attn else "right",
|
||||
)
|
||||
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg + 1
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
assert "model" not in checkpoint
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
filenames = [
|
||||
f"{params.exp_dir}/epoch-{epoch}.pt"
|
||||
for epoch in range(start, params.epoch + 1)
|
||||
]
|
||||
avg_checkpoint = average_checkpoints(filenames)
|
||||
model.load_state_dict(avg_checkpoint, strict=False)
|
||||
|
||||
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save(avg_checkpoint, filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def average_checkpoints(
|
||||
filenames: List[Path], device: torch.device = torch.device("cpu")
|
||||
) -> dict:
|
||||
"""Average a list of checkpoints.
|
||||
The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.
|
||||
|
||||
Args:
|
||||
filenames:
|
||||
Filenames of the checkpoints to be averaged. We assume all
|
||||
checkpoints are saved by :func:`save_checkpoint`.
|
||||
device:
|
||||
Move checkpoints to this device before averaging.
|
||||
Returns:
|
||||
Return a dict (i.e., state_dict) which is the average of all
|
||||
model state dicts contained in the checkpoints.
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
if "model" in torch.load(filenames[0], map_location=device):
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
else:
|
||||
avg = torch.load(filenames[0], map_location=device)
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
uniqued: Dict[int, str] = dict()
|
||||
|
||||
for k, v in avg.items():
|
||||
v_data_ptr = v.data_ptr()
|
||||
if v_data_ptr in uniqued:
|
||||
continue
|
||||
uniqued[v_data_ptr] = k
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
if "model" in torch.load(filenames[i], map_location=device):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
else:
|
||||
state_dict = torch.load(filenames[i], map_location=device)
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
for k in uniqued_names:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
avg[k] //= n
|
||||
|
||||
return avg
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="beam-search",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- beam-search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="beam size for beam search decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="whisper/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
type=str,
|
||||
default="/workspace/CosyVoice-300M-SFT",
|
||||
help="The path to the token2wav model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt_text",
|
||||
type=str,
|
||||
default="Romeo and Juliet might be the most famous act of William Shakespeare.",
|
||||
help="The prompt text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt_speech_path",
|
||||
type=str,
|
||||
default="./assets/common_voice_en_2586258.wav",
|
||||
help="The path to the prompt speech",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict({})
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
tokenizer: AutoTokenizer,
|
||||
token2wav_model: nn.Module,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: "beam-search"
|
||||
- value: A list of lists. Each sublist is a list of token IDs.
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
batch:
|
||||
It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
) -> Dict:
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
chat_template=TEMPLATE,
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||
for text in texts
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||
for text in texts
|
||||
]
|
||||
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
dtype = torch.float32
|
||||
device = model.llm.device
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device, dtype=dtype).transpose(1, 2)
|
||||
if not params.remove_whisper_encoder_input_length_restriction:
|
||||
T = 3000
|
||||
if feature.shape[2] < T:
|
||||
feature = torch.cat(
|
||||
[
|
||||
feature,
|
||||
torch.zeros(
|
||||
feature.shape[0], feature.shape[1], T - feature.shape[2]
|
||||
).to(device, dtype=dtype),
|
||||
],
|
||||
2,
|
||||
)
|
||||
|
||||
# chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
# questions_with_history = [
|
||||
# cut.custom["question"] for cut in batch["supervisions"]["cut"]
|
||||
# ]
|
||||
# history_contexts = [
|
||||
# question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
|
||||
# ]
|
||||
# last_questions = [
|
||||
# question.split("<USER>: ")[-1].strip() for question in questions_with_history
|
||||
# ]
|
||||
# messages = []
|
||||
# for i, total_round in enumerate(chat_rounds):
|
||||
# message = []
|
||||
# if total_round > 1:
|
||||
# history_question_answer = history_contexts[i].split("USER:")
|
||||
# history_question_answer = [item for item in history_question_answer if item]
|
||||
# for j in range(total_round - 1):
|
||||
# question_answer = history_question_answer[j].split("ASSISTANT:")
|
||||
# message += [
|
||||
# {"role": "user", "content": question_answer[0].strip()},
|
||||
# {"role": "assistant", "content": question_answer[1].strip()},
|
||||
# ]
|
||||
# message += [
|
||||
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
# {"role": "assistant", "content": ""},
|
||||
# ]
|
||||
# print(f"message: {message}, batch_size {len(chat_rounds)}")
|
||||
# messages.append(message)
|
||||
messages = []
|
||||
for i in range(len(batch["supervisions"]["cut"])):
|
||||
message = [
|
||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
messages.append(message)
|
||||
input_ids, attention_mask = preprocess(messages, tokenizer)
|
||||
if params.enable_speech_output:
|
||||
generated_ids, generated_speech_output = model.decode_with_speech_output(
|
||||
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
)
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
generated_speech_output = [
|
||||
generated_speech_output
|
||||
] # WAR: only support batch = 1 for now
|
||||
for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
|
||||
speech_file_name = params.log_dir / f"{cut_id}.wav"
|
||||
# audio_tokens = [token for token in audio_tokens if token < 4096]
|
||||
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
||||
if "CosyVoice2" in params.token2wav_path:
|
||||
prompt_speech_16k = load_wav(params.prompt_speech_path, 16000)
|
||||
audio_hat = audio_decode_cosyvoice2(
|
||||
audio_tokens,
|
||||
params.prompt_text,
|
||||
prompt_speech_16k,
|
||||
token2wav_model,
|
||||
)
|
||||
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
|
||||
else:
|
||||
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
|
||||
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
|
||||
else:
|
||||
generated_ids = model.decode(
|
||||
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
)
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
|
||||
print(f"hyps: {hyps}")
|
||||
return {"beam-search": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
tokenizer: AutoTokenizer,
|
||||
token2wav_model: nn.Module,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
The dataloader.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
# questions_with_history = [
|
||||
# cut.custom["question"] for cut in batch["supervisions"]["cut"]
|
||||
# ]
|
||||
# texts = [
|
||||
# question.split("<USER>: ")[-1].strip()
|
||||
# for question in questions_with_history
|
||||
# ]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
token2wav_model=token2wav_model,
|
||||
batch=batch,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
print(f"ref: {ref_text}")
|
||||
print(f"hyp: {''.join(hyp_words)}")
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
|
||||
params.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
|
||||
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model, tokenizer = get_model(params, device)
|
||||
if "CosyVoice2" in params.token2wav_path:
|
||||
token2wav_model = CosyVoice2(
|
||||
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
else:
|
||||
token2wav_model = CosyVoice(
|
||||
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
args.return_cuts = True
|
||||
data_module = AsrDataModule(args)
|
||||
|
||||
def remove_long_utt(c: Cut):
|
||||
# Keep only utterances with duration in 30 seconds
|
||||
#
|
||||
if c.duration > 30.0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
# TODO: FIX ME
|
||||
# test_sets_cuts = data_module.test_cuts_belle()
|
||||
test_sets_cuts = data_module.test_cuts_en_vocalnet()
|
||||
test_sets = test_sets_cuts.keys()
|
||||
test_dls = [
|
||||
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
|
||||
for cuts_name in test_sets
|
||||
]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
token2wav_model=token2wav_model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
256
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
Normal file
256
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
Normal file
@ -0,0 +1,256 @@
|
||||
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
|
||||
""" Example Usage
|
||||
split=test_zh
|
||||
llm_path=f5-tts/exp_zh/checkpoint-805000
|
||||
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
|
||||
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
|
||||
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
|
||||
vocoder=./bigvgan_v2_24khz_100band_256x
|
||||
torchrun --nproc_per_node=2 \
|
||||
f5-tts/infer_dist.py \
|
||||
--output_dir $output_dir \
|
||||
--batch_size 1 \
|
||||
--num_workers 2 \
|
||||
--llm-model-name-or-path $llm_path \
|
||||
--flow-matching-model-path $model_path \
|
||||
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||
--use-cosyvoice-semantic-token True \
|
||||
--vocoder-dir $vocoder \
|
||||
--split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
|
||||
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import whisper
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoTokenizer
|
||||
from web_demo import get_model
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
|
||||
# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="extract speech code")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="test",
|
||||
help="huggingface dataset split name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subset-name",
|
||||
type=str,
|
||||
default="commoneval",
|
||||
help="subset name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="dir to save result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=2, help="workers for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefetch", type=int, default=2, help="prefetch for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint name or path, default to %(default)r",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--top-k",
|
||||
# type=int,
|
||||
# default=50,
|
||||
# help="top k for sampling",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--top-p",
|
||||
# type=float,
|
||||
# default=0.95,
|
||||
# help="top p for sampling",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--temperature",
|
||||
# type=float,
|
||||
# default=0.8,
|
||||
# help="temperature for sampling",
|
||||
# )
|
||||
add_model_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group("nccl")
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer,
|
||||
):
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
chat_template=TEMPLATE,
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||
for text in texts
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||
for text in texts
|
||||
]
|
||||
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def custom_collate(batch):
|
||||
assert len(batch) == 1
|
||||
audio = batch[0]["audio"]
|
||||
assert audio["sampling_rate"] == 16000
|
||||
result = {"audio": audio["array"]}
|
||||
for keys in batch[0].keys():
|
||||
if keys != "audio":
|
||||
result[keys] = batch[0][keys]
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
dataset = load_dataset(
|
||||
"hlt-lab/voicebench",
|
||||
args.subset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
model, tokenizer = get_model(args)
|
||||
# tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=custom_collate,
|
||||
)
|
||||
|
||||
total_steps = len(dataset)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
message = [
|
||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
input_ids, attention_mask = preprocess([message], tokenizer)
|
||||
results_jsonl_file = open(
|
||||
os.path.join(
|
||||
args.output_dir,
|
||||
f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
|
||||
),
|
||||
"w",
|
||||
)
|
||||
for batch in dataloader:
|
||||
audio = batch["audio"]
|
||||
audio = torch.from_numpy(audio).to(device).to(torch.float32)
|
||||
fbank = whisper.log_mel_spectrogram(audio, device=device)
|
||||
fbank = fbank.unsqueeze(0)
|
||||
generated_ids = model.decode(
|
||||
fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
)
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
result_dict = {}
|
||||
for key in batch.keys():
|
||||
if key != "audio":
|
||||
result_dict[key] = batch[key]
|
||||
result_dict["response"] = hyps[0]
|
||||
json.dump(result_dict, results_jsonl_file)
|
||||
results_jsonl_file.write("\n")
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * args.batch_size)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
310
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_tts.py
Executable file
310
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_tts.py
Executable file
@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
# 2024 Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from datasets import Audio, load_dataset
|
||||
from decode import audio_decode_cosyvoice2
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import IGNORE_TOKEN_ID, SPEECH_LLM
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from train import add_model_arguments, add_training_arguments, get_model, get_params
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Qwen2Config,
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from utils import ( # filter_uneven_sized_batch,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_local_rank,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The batch size to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="test_en",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="huggingface dataset split name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
type=str,
|
||||
default="/workspace/CosyVoice-300M-SFT",
|
||||
help="The path to the token2wav model",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
add_training_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
) -> Dict:
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
chat_template=TEMPLATE,
|
||||
add_generation_prompt=False,
|
||||
padding="longest", # FIX me change padding to longest
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
if len(texts) != len(messages):
|
||||
logging.warning(f"Remove too long text, {messages} ")
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||
for text in texts
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||
for text in texts
|
||||
]
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
target_ids = input_ids.clone()
|
||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
|
||||
# first get the indices of the tokens
|
||||
mask_prompt = True
|
||||
if mask_prompt:
|
||||
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
||||
mask_indices = torch.where(input_ids == default_speech_token_id)
|
||||
for i in range(mask_indices[0].size(0)):
|
||||
row = mask_indices[0][i]
|
||||
col = mask_indices[1][i]
|
||||
# + 2 to skip: 'assistant', '\n'
|
||||
# WAR: TODO FIXME check qwen3
|
||||
# THIS IS THE ONLY DIFFERENCE FROM preprocess
|
||||
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
|
||||
target_ids[row, col] = default_speech_token_id
|
||||
# remove default_speech_token_id from target_ids and input_ids
|
||||
batch_size = target_ids.size(0)
|
||||
|
||||
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
|
||||
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
return input_ids, attention_mask, target_ids
|
||||
|
||||
|
||||
def data_collator(batch):
|
||||
prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
|
||||
for i, item in enumerate(batch):
|
||||
# speech_tokens.append(item["prompt_audio_cosy2_tokens"])
|
||||
message_list_item = []
|
||||
message_list_item += [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
|
||||
},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
messages.append(message_list_item)
|
||||
target_texts.append(item["target_text"])
|
||||
|
||||
ids.append(item["id"])
|
||||
prompt_texts.append(item["prompt_text"])
|
||||
speech_org = item["prompt_audio"]
|
||||
|
||||
speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
|
||||
speech_org = speech_org.mean(dim=0, keepdim=True)
|
||||
prompt_speech_16k.append(speech_org)
|
||||
|
||||
# resample to 16k
|
||||
|
||||
return {
|
||||
"prompt_texts": prompt_texts,
|
||||
"target_texts": target_texts,
|
||||
"prompt_speech_16k": prompt_speech_16k,
|
||||
"messages": messages,
|
||||
"ids": ids,
|
||||
}
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.log_dir = Path(params.exp_dir) / "log-results-wav"
|
||||
params.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
|
||||
if rank == 0:
|
||||
setup_logger(f"{params.exp_dir}/log/log-decode-tts")
|
||||
logging.info(params)
|
||||
logging.info("About to create model")
|
||||
model, tokenizer = get_model(params)
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", get_local_rank())
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logging.info(f"Device: {device}")
|
||||
model.to(device)
|
||||
|
||||
dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
|
||||
dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=params.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
prefetch_factor=1,
|
||||
collate_fn=data_collator,
|
||||
)
|
||||
token2wav_model = CosyVoice2(
|
||||
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
for batch in data_loader:
|
||||
messages = batch["messages"]
|
||||
prompt_texts = batch["prompt_texts"]
|
||||
prompt_speech_16k = batch["prompt_speech_16k"]
|
||||
target_texts = batch["target_texts"]
|
||||
ids = batch["ids"]
|
||||
input_ids, attention_mask, _ = preprocess(messages, tokenizer)
|
||||
generated_ids, generated_speech_output = model.decode_with_speech_output(
|
||||
None, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
)
|
||||
generated_speech_output = [
|
||||
generated_speech_output
|
||||
] # WAR: only support batch = 1 for now
|
||||
for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
|
||||
ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
|
||||
):
|
||||
speech_file_name = params.log_dir / f"{cut_id}.wav"
|
||||
# save target_text to file
|
||||
with open(params.log_dir / f"{cut_id}.txt", "w") as f:
|
||||
f.write(f"{target_text}\n")
|
||||
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
||||
if "CosyVoice2" in params.token2wav_path:
|
||||
audio_hat = audio_decode_cosyvoice2(
|
||||
audio_tokens,
|
||||
prompt_text,
|
||||
prompt_speech,
|
||||
token2wav_model,
|
||||
)
|
||||
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 24000)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
world_size = get_world_size()
|
||||
rank = get_rank()
|
||||
|
||||
torch.set_num_threads(1)
|
||||
# torch.set_num_interop_threads(1)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
run(rank=rank, world_size=world_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json
Symbolic link
1
egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json
|
1
egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py
Symbolic link
1
egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/conformer_ctc/label_smoothing.py
|
838
egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
Normal file
838
egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
Normal file
@ -0,0 +1,838 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
|
||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
import logging
|
||||
|
||||
|
||||
class EncoderProjector(nn.Module):
|
||||
"""
|
||||
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
|
||||
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
|
||||
Args:
|
||||
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
|
||||
llm_dim (:obj:`int`): The dimension of the language model.
|
||||
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
|
||||
super().__init__()
|
||||
self.downsample_rate = downsample_rate
|
||||
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
|
||||
self.relu = nn.ReLU()
|
||||
self.linear2 = nn.Linear(llm_dim, llm_dim)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
batch_size, seq_len, feat_dim = x.size()
|
||||
num_frames_to_discard = seq_len % self.downsample_rate
|
||||
if num_frames_to_discard > 0:
|
||||
x = x[:, :-num_frames_to_discard, :]
|
||||
seq_len = x.size(1)
|
||||
|
||||
x = x.contiguous()
|
||||
x = x.view(
|
||||
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
|
||||
)
|
||||
|
||||
x = self.linear1(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
class SPEECH_LLM(nn.Module):
|
||||
"""
|
||||
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
|
||||
The encoder is used to extract speech features from the input speech signal.
|
||||
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
|
||||
The language model is used to generate the text from the speech features.
|
||||
Args:
|
||||
encoder (:obj:`nn.Module`): The encoder module.
|
||||
llm (:obj:`nn.Module`): The language model module.
|
||||
encoder_projector (:obj:`nn.Module`): The encoder projector module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: nn.Module = None,
|
||||
llm: nn.Module = None,
|
||||
encoder_projector: nn.Module = None,
|
||||
codec_lm: nn.Module = None,
|
||||
codec_lm_padding_side: str = "left",
|
||||
teacher_llm: nn.Module = None,
|
||||
kl_temperature: float = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.llm = llm
|
||||
self.encoder_projector = encoder_projector
|
||||
self.codec_lm = codec_lm
|
||||
if self.codec_lm:
|
||||
self.speech_token_projector = nn.Linear(
|
||||
self.llm.config.hidden_size + self.llm.config.hidden_size,
|
||||
self.codec_lm.config.hidden_size,
|
||||
)
|
||||
self.codec_lm_head = nn.Linear(
|
||||
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
|
||||
)
|
||||
self.speech_token_projector = self.speech_token_projector.to(
|
||||
dtype=torch.float16
|
||||
)
|
||||
self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss()
|
||||
self.codec_lm_padding_side = codec_lm_padding_side
|
||||
|
||||
self.audio_accuracy_metric = MulticlassAccuracy(
|
||||
self.codec_lm.vocab_size,
|
||||
top_k=10,
|
||||
average="micro",
|
||||
multidim_average="global",
|
||||
ignore_index=IGNORE_TOKEN_ID,
|
||||
)
|
||||
if teacher_llm is not None:
|
||||
self.teacher_llm = teacher_llm
|
||||
self.kl_temperature = kl_temperature
|
||||
|
||||
def _merge_input_ids_with_speech_features(
|
||||
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
|
||||
):
|
||||
"""
|
||||
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
|
||||
with the speech features and padding the input_ids to the maximum length of the speech features.
|
||||
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
|
||||
Args:
|
||||
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
|
||||
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
|
||||
input_ids (:obj:`torch.Tensor`): The input ids to merge.
|
||||
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
|
||||
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
|
||||
Returns:
|
||||
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
|
||||
"""
|
||||
num_speechs, speech_len, embed_dim = speech_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(
|
||||
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
|
||||
)
|
||||
# 1. Create a mask to know where special speech tokens are
|
||||
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
|
||||
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (
|
||||
num_special_speech_tokens.max() * (speech_len - 1)
|
||||
) + sequence_length
|
||||
batch_indices, non_speech_indices = torch.where(
|
||||
input_ids != self.llm.config.default_speech_token_id
|
||||
)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged speech-text sequence.
|
||||
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
|
||||
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = (
|
||||
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
|
||||
)
|
||||
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_speech_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size,
|
||||
max_embed_dim,
|
||||
embed_dim,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size,
|
||||
max_embed_dim,
|
||||
dtype=attention_mask.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim),
|
||||
IGNORE_TOKEN_ID,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_speech_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_speech_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
|
||||
batch_indices, non_speech_indices
|
||||
]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
|
||||
batch_indices, non_speech_indices
|
||||
]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[
|
||||
batch_indices, non_speech_indices
|
||||
]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
|
||||
speech_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim),
|
||||
True,
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
speech_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
|
||||
:, None
|
||||
].to(target_device)
|
||||
|
||||
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
|
||||
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[speech_to_overwrite] = (
|
||||
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
)
|
||||
final_attention_mask |= speech_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
|
||||
(final_attention_mask == 0), 1
|
||||
)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(
|
||||
input_ids == self.llm.config.pad_token_id
|
||||
)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels: torch.LongTensor = None,
|
||||
):
|
||||
encoder_outs = self.encoder(fbank)
|
||||
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
labels,
|
||||
_,
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc = compute_accuracy(
|
||||
preds.detach()[:, :-1],
|
||||
labels.detach()[:, 1:],
|
||||
ignore_label=IGNORE_TOKEN_ID,
|
||||
)
|
||||
return model_outputs.loss, acc
|
||||
|
||||
def forward_kl_div(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels: torch.LongTensor = None,
|
||||
teacher_input_ids: torch.LongTensor = None,
|
||||
teacher_attention_mask: torch.Tensor = None,
|
||||
teacher_labels: torch.LongTensor = None,
|
||||
):
|
||||
encoder_outs = self.encoder(fbank)
|
||||
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
labels,
|
||||
_,
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
|
||||
)
|
||||
|
||||
teacher_outputs = self.teacher_llm(
|
||||
input_ids=teacher_input_ids,
|
||||
attention_mask=teacher_attention_mask,
|
||||
)
|
||||
|
||||
kl_loss = torch.nn.functional.kl_div(
|
||||
torch.nn.functional.log_softmax(
|
||||
model_outputs.logits[labels != -100] / self.kl_temperature,
|
||||
dim=-1,
|
||||
),
|
||||
torch.nn.functional.softmax(
|
||||
teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature,
|
||||
dim=-1,
|
||||
),
|
||||
reduction="batchmean",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
teacher_preds = torch.argmax(teacher_outputs.logits, -1)
|
||||
acc = compute_accuracy(
|
||||
preds.detach()[:, :-1],
|
||||
labels.detach()[:, 1:],
|
||||
ignore_label=IGNORE_TOKEN_ID,
|
||||
)
|
||||
acc_teacher = compute_accuracy(
|
||||
teacher_preds.detach()[:, :-1],
|
||||
teacher_labels.detach()[:, 1:],
|
||||
ignore_label=IGNORE_TOKEN_ID,
|
||||
)
|
||||
return kl_loss, acc, acc_teacher
|
||||
|
||||
def forward_with_speech_output(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels: torch.LongTensor = None,
|
||||
speech_codec_ids: torch.LongTensor = None,
|
||||
):
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
if fbank is not None:
|
||||
encoder_outs = self.encoder(fbank)
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
labels,
|
||||
_,
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
input_seq_len = attention_mask.sum(dim=1) # shape, B
|
||||
(
|
||||
text_label_start_index_list,
|
||||
text_input_start_index_list,
|
||||
input_question_len_list,
|
||||
) = ([], [], [])
|
||||
for i in range(labels.shape[0]):
|
||||
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
|
||||
input_embeds_start_index = input_embeds_valid_index[0]
|
||||
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
|
||||
text_labels_start_index = text_labels_valid_index[0]
|
||||
|
||||
assert (
|
||||
input_seq_len[i]
|
||||
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
|
||||
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
|
||||
assert (
|
||||
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
|
||||
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
|
||||
input_question_len = text_labels_start_index - input_embeds_start_index
|
||||
assert (
|
||||
input_question_len
|
||||
+ text_labels_valid_index[-1]
|
||||
- text_labels_start_index
|
||||
+ 1
|
||||
== input_seq_len[i]
|
||||
)
|
||||
text_label_start_index_list.append(text_labels_start_index)
|
||||
text_input_start_index_list.append(input_embeds_start_index)
|
||||
input_question_len_list.append(input_question_len)
|
||||
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
text_loss = model_outputs.loss
|
||||
delay_step = 1
|
||||
# prepare codec lm inputs
|
||||
audio_codes_lens = [
|
||||
len(x) + input_question_len_list[i] + delay_step + 1
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
max_len_speech_codec = max(audio_codes_lens)
|
||||
|
||||
if self.codec_lm_padding_side == "right":
|
||||
audio_codes = [
|
||||
[self.codec_lm.config.mask_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ [self.codec_lm.config.bos_token_id]
|
||||
+ x
|
||||
+ [self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
audio_labels = [
|
||||
[self.codec_lm.config.pad_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ x
|
||||
+ [self.codec_lm.config.eos_token_id]
|
||||
+ [self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
elif self.codec_lm_padding_side == "left":
|
||||
audio_codes = [
|
||||
[self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
+ [self.codec_lm.config.mask_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ [self.codec_lm.config.bos_token_id]
|
||||
+ x
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
audio_labels = [
|
||||
[self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
+ [self.codec_lm.config.pad_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ x
|
||||
+ [self.codec_lm.config.eos_token_id]
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
audio_codes = torch.tensor(
|
||||
audio_codes, dtype=torch.int64, device=input_ids.device
|
||||
)
|
||||
audio_labels = torch.tensor(
|
||||
audio_labels, dtype=torch.int64, device=input_ids.device
|
||||
)
|
||||
|
||||
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
|
||||
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
|
||||
|
||||
# text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
|
||||
text_input_embeds_list = []
|
||||
for i in range(len(text_label_start_index_list)):
|
||||
text_last_hidden = model_outputs.hidden_states[-1][
|
||||
i,
|
||||
text_input_start_index_list[i] : text_input_start_index_list[i]
|
||||
+ input_seq_len[i]
|
||||
- 1,
|
||||
]
|
||||
# text_last_hidden_lists.append(text_last_hidden)
|
||||
text_embed = inputs_embeds[
|
||||
i,
|
||||
text_input_start_index_list[i]
|
||||
+ 1 : text_input_start_index_list[i]
|
||||
+ input_seq_len[i],
|
||||
] # exclude bos
|
||||
# text_embeds_list.append(text_embed)
|
||||
|
||||
text_input_embeds = torch.cat(
|
||||
[
|
||||
text_last_hidden,
|
||||
text_embed,
|
||||
],
|
||||
dim=-1,
|
||||
) # shape, T, D1 + D2
|
||||
text_input_embeds = self.speech_token_projector(
|
||||
text_input_embeds
|
||||
) # shape, T, D_codec
|
||||
text_input_embeds_list.append(text_input_embeds)
|
||||
|
||||
for i in range(audio_embeddings.shape[0]):
|
||||
text_input_embeds = text_input_embeds_list[i]
|
||||
if self.codec_lm_padding_side == "right":
|
||||
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
|
||||
elif self.codec_lm_padding_side == "left":
|
||||
start_idx = torch.where(
|
||||
audio_codes[i] == self.codec_lm.config.mask_token_id
|
||||
)[0][0]
|
||||
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
|
||||
assert (
|
||||
start_idx == start_idx_re_compute
|
||||
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
|
||||
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
|
||||
logging.warning(
|
||||
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
|
||||
)
|
||||
# breakpoint()
|
||||
text_input_embeds = text_input_embeds[
|
||||
: audio_embeddings.shape[1] - start_idx
|
||||
]
|
||||
audio_embeddings[
|
||||
i, start_idx : start_idx + text_input_embeds.shape[0]
|
||||
] += text_input_embeds
|
||||
|
||||
speech_outputs = self.codec_lm(
|
||||
attention_mask=audio_attention_mask,
|
||||
inputs_embeds=audio_embeddings,
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
last_hidden_state = speech_outputs.hidden_states[-1].clone()
|
||||
|
||||
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
|
||||
audio_logits = audio_logits.contiguous().view(
|
||||
-1, self.codec_lm.config.vocab_size
|
||||
)
|
||||
audio_labels = audio_labels.contiguous().view(-1)
|
||||
audio_labels = audio_labels.masked_fill(
|
||||
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
|
||||
)
|
||||
codec_loss = self.loss_fct(audio_logits, audio_labels)
|
||||
audio_preds = torch.argmax(audio_logits, -1)
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc = compute_accuracy(
|
||||
preds.detach()[:, :-1],
|
||||
labels.detach()[:, 1:],
|
||||
ignore_label=IGNORE_TOKEN_ID,
|
||||
)
|
||||
audio_acc = compute_accuracy(
|
||||
audio_preds.detach(),
|
||||
audio_labels.detach(),
|
||||
ignore_label=IGNORE_TOKEN_ID,
|
||||
)
|
||||
audio_topk_acc = self.audio_accuracy_metric(
|
||||
audio_logits.detach(), audio_labels.detach()
|
||||
).item()
|
||||
|
||||
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
|
||||
|
||||
def decode(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
encoder_outs = self.encoder(fbank)
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
speech_features = speech_features.to(torch.float16)
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
_,
|
||||
_,
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask
|
||||
)
|
||||
generated_ids = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=kwargs.get("max_new_tokens", 1024),
|
||||
num_beams=kwargs.get("num_beams", 1),
|
||||
do_sample=kwargs.get("do_sample", True),
|
||||
min_length=kwargs.get("min_length", 1),
|
||||
top_p=kwargs.get("top_p", 0.5),
|
||||
top_k=kwargs.get("top_k", 20),
|
||||
repetition_penalty=kwargs.get("repetition_penalty", 1.1),
|
||||
temperature=kwargs.get("temperature", 0.7),
|
||||
bos_token_id=self.llm.config.bos_token_id,
|
||||
eos_token_id=self.llm.config.eos_token_id,
|
||||
pad_token_id=self.llm.config.pad_token_id,
|
||||
)
|
||||
|
||||
return generated_ids
|
||||
|
||||
def decode_with_speech_output(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None, # Prompt input_ids
|
||||
attention_mask: torch.Tensor = None, # Prompt attention_mask
|
||||
max_text_new_tokens: int = 1024,
|
||||
max_speech_new_tokens: int = 2048, # Max length for speech tokens
|
||||
llm_kwargs: dict = None, # Kwargs for text LLM generate
|
||||
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
|
||||
) -> Tuple[torch.LongTensor, List[List[int]]]:
|
||||
"""
|
||||
Generates text and corresponding speech tokens using the revised logic.
|
||||
|
||||
Args:
|
||||
fbank: Input audio features.
|
||||
input_ids: Input token IDs for the text prompt.
|
||||
attention_mask: Attention mask for the text prompt.
|
||||
max_text_new_tokens: Max new tokens for text generation.
|
||||
max_speech_new_tokens: Max new tokens for speech generation.
|
||||
llm_kwargs: Additional arguments for self.llm.generate.
|
||||
codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.LongTensor, List[List[int]]]:
|
||||
- generated_text_ids: Tensor of generated text token IDs (including prompt).
|
||||
- generated_speech_tokens: List of lists, where each inner list contains
|
||||
the generated speech codec tokens for a batch item.
|
||||
"""
|
||||
batch_size = input_ids.shape[0]
|
||||
assert batch_size == 1, "Batch size must be 1 for speech generation."
|
||||
|
||||
device = next(self.parameters()).device # Use model's device
|
||||
|
||||
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
|
||||
# Merge speech features with prompt embeddings
|
||||
if fbank is not None:
|
||||
encoder_outs = self.encoder(fbank)
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
|
||||
(
|
||||
merged_prompt_inputs_embeds,
|
||||
merged_prompt_attention_mask,
|
||||
_,
|
||||
_,
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
speech_features, prompt_embeds, input_ids, attention_mask
|
||||
)
|
||||
else:
|
||||
merged_prompt_inputs_embeds = prompt_embeds
|
||||
merged_prompt_attention_mask = attention_mask
|
||||
|
||||
# --- 2. Generate Text using LLM ---
|
||||
# Use merged embeds/mask as input to generate
|
||||
# Ensure kwargs passed are suitable for llm.generate
|
||||
# Note: Using default generation params from `decode` if not provided in kwargs
|
||||
final_llm_kwargs = {
|
||||
"bos_token_id": self.llm.config.bos_token_id,
|
||||
"eos_token_id": self.llm.config.eos_token_id,
|
||||
"pad_token_id": self.llm.config.pad_token_id,
|
||||
"num_beams": 1,
|
||||
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
|
||||
"top_p": 0.5,
|
||||
"top_k": 20,
|
||||
"repetition_penalty": 1.1,
|
||||
"temperature": 0.7,
|
||||
**(llm_kwargs or {}), # User-provided kwargs override defaults
|
||||
}
|
||||
|
||||
text_outputs = self.llm.generate(
|
||||
inputs_embeds=merged_prompt_inputs_embeds,
|
||||
attention_mask=merged_prompt_attention_mask,
|
||||
max_new_tokens=max_text_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_hidden_states=True,
|
||||
**final_llm_kwargs,
|
||||
)
|
||||
delay_step = 1
|
||||
generated_text_ids = text_outputs.sequences # [B, S_full]
|
||||
eos_token_id = self.llm.config.eos_token_id
|
||||
eos_token_embedding = self.llm.get_input_embeddings()(
|
||||
torch.tensor([[eos_token_id]], device=device)
|
||||
)
|
||||
assert (
|
||||
generated_text_ids[0, -1] == eos_token_id
|
||||
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
|
||||
thinker_token_embeds_org = [
|
||||
token_hidden_states[0].to(self.llm.device)
|
||||
for token_hidden_states in text_outputs.hidden_states
|
||||
]
|
||||
|
||||
first_thinker_token_embed = torch.cat(
|
||||
[
|
||||
thinker_token_embeds_org[0][:, 1:],
|
||||
thinker_token_embeds_org[1],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
thinker_token_embeds = (
|
||||
[first_thinker_token_embed]
|
||||
+ thinker_token_embeds_org[2:]
|
||||
+ [eos_token_embedding]
|
||||
)
|
||||
thinker_hidden_states = [
|
||||
token_hidden_states[-1].to(self.llm.device)
|
||||
for token_hidden_states in text_outputs.hidden_states
|
||||
]
|
||||
|
||||
thinker_reply_part = [
|
||||
torch.cat(
|
||||
[
|
||||
thinker_hidden_state,
|
||||
thinker_token_embed,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
for thinker_hidden_state, thinker_token_embed in zip(
|
||||
thinker_hidden_states[1:], thinker_token_embeds[1:]
|
||||
)
|
||||
]
|
||||
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
|
||||
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
|
||||
thinker_prompt_part = torch.cat(
|
||||
[
|
||||
thinker_hidden_states[0],
|
||||
thinker_token_embeds[0],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
|
||||
thinker_reply_part = self.speech_token_projector(thinker_reply_part)
|
||||
|
||||
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
|
||||
talker_input_ids = torch.full(
|
||||
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
|
||||
self.codec_lm.config.mask_token_id,
|
||||
dtype=torch.long,
|
||||
device=self.llm.device,
|
||||
)
|
||||
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
|
||||
thinker_input_embeds = torch.cat(
|
||||
[
|
||||
thinker_prompt_part,
|
||||
thinker_reply_part[:, : delay_step + 1, :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
talker_inputs_embeds += thinker_input_embeds
|
||||
thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
|
||||
|
||||
past_key_values = None
|
||||
|
||||
generated_speech_tokens_list = []
|
||||
next_token_ids = None
|
||||
|
||||
for t in range(max_speech_new_tokens):
|
||||
if t > 0:
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||
next_token_ids
|
||||
)
|
||||
if thinker_reply_part.shape[1] > 0:
|
||||
talker_inputs_embeds += thinker_reply_part[:, :1, :]
|
||||
thinker_reply_part = thinker_reply_part[:, 1:, :]
|
||||
|
||||
codec_outputs = self.codec_lm(
|
||||
inputs_embeds=talker_inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
|
||||
next_token_logits = self.codec_lm_head(last_token_hidden_state)
|
||||
|
||||
next_token_ids = topk_sampling(
|
||||
next_token_logits,
|
||||
)
|
||||
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
|
||||
break
|
||||
|
||||
past_key_values = codec_outputs.past_key_values # Update KV cache
|
||||
generated_speech_tokens_list.append(
|
||||
next_token_ids.squeeze(1).cpu().tolist()[0]
|
||||
)
|
||||
|
||||
return generated_text_ids, generated_speech_tokens_list
|
||||
|
||||
|
||||
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
|
||||
Args:
|
||||
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(
|
||||
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
|
||||
)
|
||||
denominator = torch.sum(mask)
|
||||
return numerator.float() / denominator.float()
|
||||
|
||||
|
||||
def topk_sampling(
|
||||
logits,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
temperature=0.8,
|
||||
):
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
# Top-p/top-k filtering
|
||||
logits_filtered = top_k_top_p_filtering(
|
||||
logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
||||
)
|
||||
# Sample
|
||||
probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
|
||||
tokens = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
||||
def top_k_top_p_filtering(
|
||||
logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
|
||||
):
|
||||
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
if top_k > 0:
|
||||
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
if min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
@ -0,0 +1,23 @@
|
||||
conformer==0.3.2
|
||||
diffusers==0.29.0
|
||||
gdown==5.1.0
|
||||
gradio
|
||||
hydra-core==1.3.2
|
||||
HyperPyYAML==1.2.2
|
||||
inflect==7.3.1
|
||||
librosa==0.10.2
|
||||
lightning==2.2.4
|
||||
matplotlib==3.7.5
|
||||
#modelscope==1.15.0
|
||||
networkx==3.1
|
||||
omegaconf==2.3.0
|
||||
onnx==1.16.0
|
||||
onnxruntime-gpu==1.18.0
|
||||
protobuf==4.25
|
||||
pydantic==2.7.0
|
||||
pyworld==0.3.4
|
||||
rich==13.7.1
|
||||
soundfile==0.12.1
|
||||
tensorboard==2.14.0
|
||||
wget==3.2
|
||||
WeTextProcessing==1.0.3
|
15
egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt
Normal file
15
egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt
Normal file
@ -0,0 +1,15 @@
|
||||
openai-whisper
|
||||
kaldialign
|
||||
lhotse
|
||||
sentencepiece
|
||||
pypinyin
|
||||
tensorboard
|
||||
librosa
|
||||
deepspeed
|
||||
transformers>=4.37.0
|
||||
flash-attn
|
||||
peft
|
||||
torchmetrics
|
||||
# triton==3.3.0 # may be violate with openai-whisper
|
||||
gradio
|
||||
sherpa-onnx
|
131
egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
Normal file
131
egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
Normal file
@ -0,0 +1,131 @@
|
||||
# server.py
|
||||
import argparse
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
import whisper
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoTokenizer
|
||||
from web_demo import get_model
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="extract speech code")
|
||||
parser.add_argument(
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint name or path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8001,
|
||||
help="Port number",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class SpeechRequest(BaseModel):
|
||||
audio: List[float] # Expecting audio as a list of floats (raw waveform)
|
||||
sampling_rate: int = 16000
|
||||
|
||||
|
||||
class TextResponse(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
def preprocess_prompt(tokenizer):
|
||||
"""Preprocesses the prompt template."""
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
message, # Using the hardcoded message
|
||||
tokenize=True,
|
||||
add_generation_prompt=False, # Important for generation
|
||||
chat_template=TEMPLATE,
|
||||
padding=False, # No padding needed for single prompt
|
||||
truncation=False,
|
||||
)
|
||||
]
|
||||
input_ids = torch.tensor(texts, dtype=torch.long)
|
||||
attention_mask = torch.ones_like(
|
||||
input_ids, dtype=torch.bool
|
||||
) # Mask is all True for the prompt
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
args = get_args()
|
||||
print(f"Using port: {args.port}")
|
||||
model, tokenizer = get_model(args)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda")
|
||||
if args.prompt_template is None:
|
||||
template = f"{DEFAULT_SPEECH_TOKEN}"
|
||||
elif args.prompt_template == "qa":
|
||||
template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"
|
||||
elif args.prompt_template == "continuation":
|
||||
template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}"
|
||||
elif args.prompt_template == "asr":
|
||||
template = (
|
||||
f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
|
||||
)
|
||||
elif args.prompt_template == "mt":
|
||||
template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}"
|
||||
else:
|
||||
raise ValueError(f"Invalid prompt template: {args.prompt_template}")
|
||||
print("Using template:", template)
|
||||
message = [
|
||||
{"role": "user", "content": template},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
|
||||
prompt_input_ids = prompt_input_ids.to(device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
|
||||
@app.post("/decode", response_model=TextResponse)
|
||||
async def decode_speech(request: SpeechRequest):
|
||||
"""
|
||||
Receives audio waveform, processes it, and returns the decoded text.
|
||||
"""
|
||||
if request.sampling_rate != 16000:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Only 16kHz sampling rate is supported."
|
||||
)
|
||||
|
||||
try:
|
||||
audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
|
||||
fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
|
||||
fbank = fbank.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)
|
||||
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
response_text = hyps[0] if hyps else ""
|
||||
|
||||
return TextResponse(text=response_text)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting server...")
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
1160
egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
Executable file
1160
egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
Executable file
File diff suppressed because it is too large
Load Diff
604
egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
Executable file
604
egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
Executable file
@ -0,0 +1,604 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
# 2024 Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import deepspeed
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import IGNORE_TOKEN_ID, SPEECH_LLM
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Qwen2Config,
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torch.utils.data import DistributedSampler, DataLoader
|
||||
from pathlib import Path
|
||||
|
||||
from train import add_model_arguments, add_training_arguments, get_params, get_model
|
||||
from utils import ( # filter_uneven_sized_batch,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_local_rank,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The batch size to use.",
|
||||
)
|
||||
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
add_model_arguments(parser)
|
||||
add_training_arguments(parser)
|
||||
return parser
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
) -> Dict:
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
chat_template=TEMPLATE,
|
||||
add_generation_prompt=False,
|
||||
padding="longest", # FIX me change padding to longest
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
if len(texts) != len(messages):
|
||||
logging.warning(f"Remove too long text, {messages} ")
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||
for text in texts
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||
for text in texts
|
||||
]
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
target_ids = input_ids.clone()
|
||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
|
||||
# first get the indices of the tokens
|
||||
mask_prompt = True
|
||||
if mask_prompt:
|
||||
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
||||
mask_indices = torch.where(input_ids == default_speech_token_id)
|
||||
for i in range(mask_indices[0].size(0)):
|
||||
row = mask_indices[0][i]
|
||||
col = mask_indices[1][i]
|
||||
# + 2 to skip: 'assistant', '\n'
|
||||
# WAR: TODO FIXME check qwen3
|
||||
# THIS IS THE ONLY DIFFERENCE FROM preprocess
|
||||
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
|
||||
target_ids[row, col] = default_speech_token_id
|
||||
# remove default_speech_token_id from target_ids and input_ids
|
||||
batch_size = target_ids.size(0)
|
||||
|
||||
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
|
||||
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
return input_ids, attention_mask, target_ids
|
||||
|
||||
def data_collator(batch):
|
||||
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
|
||||
for i, item in enumerate(batch):
|
||||
speech_tokens.append(item["code"])
|
||||
message_list_item = []
|
||||
message_list_item += [
|
||||
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": item["text"]},
|
||||
]
|
||||
# message_list_item += [
|
||||
# {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
|
||||
# {"role": "assistant", "content": item["text"]},
|
||||
# ]
|
||||
messages.append(message_list_item)
|
||||
durations.append(item["duration"])
|
||||
ids.append(item["index"] if "index" in item else item["id"])
|
||||
lang.append(item["language"])
|
||||
|
||||
return {
|
||||
"speech_tokens": speech_tokens,
|
||||
"messages": messages,
|
||||
"durations": durations,
|
||||
"ids": ids,
|
||||
"lang": lang,
|
||||
}
|
||||
|
||||
def data_collator_ultra_chat(batch):
|
||||
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
|
||||
for i, item in enumerate(batch):
|
||||
speech_tokens.append(item["custom"]["speech_token"])
|
||||
text = item["supervisions"][0]["text"]
|
||||
message_list_item = []
|
||||
message_list_item += [
|
||||
{"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": text},
|
||||
]
|
||||
messages.append(message_list_item)
|
||||
durations.append(item["duration"])
|
||||
ids.append(item["id"])
|
||||
|
||||
return {
|
||||
"speech_tokens": speech_tokens,
|
||||
"messages": messages,
|
||||
"durations": durations,
|
||||
"ids": ids,
|
||||
}
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
tokenizer: AutoTokenizer,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute the loss for the given batch.
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
tokenizer:
|
||||
The tokenizer used to encode the text.
|
||||
model:
|
||||
The model for training.
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
is_training:
|
||||
Whether it is training.
|
||||
Returns:
|
||||
Return a tuple of two elements. The first element is the loss tensor.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
|
||||
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
|
||||
target_ids = target_ids.type(torch.LongTensor)
|
||||
input_ids = input_ids.type(torch.LongTensor)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
(
|
||||
text_loss,
|
||||
acc,
|
||||
codec_loss,
|
||||
codec_acc,
|
||||
codec_topk_acc,
|
||||
) = model.forward_with_speech_output(
|
||||
input_ids=input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
labels=target_ids.to(device),
|
||||
speech_codec_ids=answer_cosyvoice_speech_token,
|
||||
)
|
||||
loss = text_loss + codec_loss
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = len(messages)
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["acc"] = acc * len(messages)
|
||||
info["codec_acc"] = codec_acc * len(messages)
|
||||
info["codec_topk_acc"] = codec_topk_acc * len(messages)
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["codec_loss"] = codec_loss.detach().cpu().item()
|
||||
info["text_loss"] = text_loss.detach().cpu().item()
|
||||
return loss, info
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
tokenizer: AutoTokenizer,
|
||||
model: nn.Module,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
with torch.amp.autocast("cuda", enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
# FIX ME
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
|
||||
loss_value = tot_loss["loss"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
tokenizer: AutoTokenizer,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler, we call step() every step.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
rank:
|
||||
The rank of the node in DDP training. If no DDP is used, it should
|
||||
be set to 0.
|
||||
"""
|
||||
model.train()
|
||||
# model.encoder.eval()
|
||||
if not params.unfreeze_llm:
|
||||
model.llm.eval()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["durations"])
|
||||
if batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
# model.encoder.eval()
|
||||
if not params.unfreeze_llm:
|
||||
model.llm.eval()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
if batch_idx != 0:
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"zero-checkpoint-{params.batch_idx_train}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
|
||||
tag=f"zero-checkpoint-{params.batch_idx_train}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
# sampler_state_dict = train_dl.sampler.state_dict()
|
||||
sampler_state_dict = train_dl.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
|
||||
)
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
|
||||
)
|
||||
try:
|
||||
with torch.amp.autocast("cuda", enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
|
||||
# deepspeed's backward() is different from torch's backward()
|
||||
# in that it does not accept a loss tensor as input.
|
||||
# It computes the loss internally.
|
||||
model.backward(loss)
|
||||
model.step()
|
||||
|
||||
except: # noqa
|
||||
raise
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
try:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
except: # noqa
|
||||
cur_lr = 0.0
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
f"lr: {cur_lr:.2e}, "
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
loss_value = tot_loss["loss"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.valid_interval = 2000
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
|
||||
if rank == 0:
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info(params)
|
||||
logging.info("About to create model")
|
||||
model, tokenizer = get_model(params)
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", get_local_rank())
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logging.info(f"Device: {device}")
|
||||
model.to(device)
|
||||
|
||||
# assert params.deepspeed and world_size > 1
|
||||
logging.info("Using DeepSpeed")
|
||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||
args=params, model=model, model_parameters=model.parameters()
|
||||
)
|
||||
|
||||
sampler_state_dict = None
|
||||
if params.sampler_state_dict_path:
|
||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||
if params.dataset == "ultra_chat_voice_assistant":
|
||||
data_dir = "data/fbank"
|
||||
json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
|
||||
ds = load_dataset("json", data_files=json_file_lists, split="train")
|
||||
# shuffle the dataset
|
||||
train_dataset = ds.shuffle(seed=42)
|
||||
eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
|
||||
else:
|
||||
data_dir = Path(params.dataset)
|
||||
json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
|
||||
ds = load_dataset("json", data_files=json_file_lists, split="train")
|
||||
# shuffle the dataset
|
||||
ds = ds.shuffle(seed=42)
|
||||
train_test_split = ds.train_test_split(test_size=1000, seed=42)
|
||||
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
|
||||
|
||||
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
|
||||
train_dl = StatefulDataLoader(
|
||||
train_dataset,
|
||||
batch_size=params.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
prefetch_factor=2,
|
||||
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
|
||||
)
|
||||
train_dl.load_state_dict(sampler_state_dict)
|
||||
valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
|
||||
valid_dl = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=params.batch_size,
|
||||
sampler=valid_sampler,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
prefetch_factor=1,
|
||||
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
|
||||
)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
logging.info(f"start training from epoch {params.start_epoch}")
|
||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"zero-epoch-{params.cur_epoch}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}",
|
||||
tag=f"zero-epoch-{params.cur_epoch}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
# sampler_state_dict = train_dl.sampler.state_dict()
|
||||
sampler_state_dict = train_dl.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
|
||||
)
|
||||
|
||||
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
world_size = get_world_size()
|
||||
rank = get_rank()
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
run(rank=rank, world_size=world_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
433
egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
Normal file
433
egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py
Normal file
@ -0,0 +1,433 @@
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||
from tqdm import tqdm
|
||||
import kaldialign
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if "WORLD_SIZE" in os.environ:
|
||||
return int(os.environ["WORLD_SIZE"])
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def get_rank():
|
||||
if "RANK" in os.environ:
|
||||
return int(os.environ["RANK"])
|
||||
elif dist.is_available() and dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_local_rank():
|
||||
if "LOCAL_RANK" in os.environ:
|
||||
return int(os.environ["LOCAL_RANK"])
|
||||
elif dist.is_available() and dist.is_initialized():
|
||||
return dist.get_local_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
"""Used in argparse.ArgumentParser.add_argument to indicate
|
||||
that a type is a bool type and user can enter
|
||||
|
||||
- yes, true, t, y, 1, to represent True
|
||||
- no, false, f, n, 0, to represent False
|
||||
|
||||
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
class AttributeDict(dict):
|
||||
def __getattr__(self, key):
|
||||
if key in self:
|
||||
return self[key]
|
||||
raise AttributeError(f"No such attribute '{key}'")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key in self:
|
||||
del self[key]
|
||||
return
|
||||
raise AttributeError(f"No such attribute '{key}'")
|
||||
|
||||
def __str__(self, indent: int = 2):
|
||||
tmp = {}
|
||||
for k, v in self.items():
|
||||
# PosixPath is ont JSON serializable
|
||||
if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
|
||||
v = str(v)
|
||||
tmp[k] = v
|
||||
return json.dumps(tmp, indent=indent, sort_keys=True)
|
||||
|
||||
|
||||
def setup_logger(
|
||||
log_filename: Pathlike,
|
||||
log_level: str = "info",
|
||||
use_console: bool = True,
|
||||
) -> None:
|
||||
"""Setup log level.
|
||||
|
||||
Args:
|
||||
log_filename:
|
||||
The filename to save the log.
|
||||
log_level:
|
||||
The log level to use, e.g., "debug", "info", "warning", "error",
|
||||
"critical"
|
||||
use_console:
|
||||
True to also print logs to console.
|
||||
"""
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
||||
log_filename = f"{log_filename}-{date_time}-{rank}"
|
||||
else:
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
log_filename = f"{log_filename}-{date_time}"
|
||||
|
||||
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
|
||||
|
||||
level = logging.ERROR
|
||||
if log_level == "debug":
|
||||
level = logging.DEBUG
|
||||
elif log_level == "info":
|
||||
level = logging.INFO
|
||||
elif log_level == "warning":
|
||||
level = logging.WARNING
|
||||
elif log_level == "critical":
|
||||
level = logging.CRITICAL
|
||||
|
||||
logging.basicConfig(
|
||||
filename=log_filename,
|
||||
format=formatter,
|
||||
level=level,
|
||||
filemode="w",
|
||||
force=True,
|
||||
)
|
||||
if use_console:
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(level)
|
||||
console.setFormatter(logging.Formatter(formatter))
|
||||
logging.getLogger("").addHandler(console)
|
||||
|
||||
|
||||
class MetricsTracker(collections.defaultdict):
|
||||
def __init__(self):
|
||||
# Passing the type 'int' to the base-class constructor
|
||||
# makes undefined items default to int() which is zero.
|
||||
# This class will play a role as metrics tracker.
|
||||
# It can record many metrics, including but not limited to loss.
|
||||
super(MetricsTracker, self).__init__(int)
|
||||
|
||||
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
|
||||
ans = MetricsTracker()
|
||||
for k, v in self.items():
|
||||
ans[k] = v
|
||||
for k, v in other.items():
|
||||
if v - v == 0:
|
||||
ans[k] = ans[k] + v
|
||||
return ans
|
||||
|
||||
def __mul__(self, alpha: float) -> "MetricsTracker":
|
||||
ans = MetricsTracker()
|
||||
for k, v in self.items():
|
||||
ans[k] = v * alpha
|
||||
return ans
|
||||
|
||||
def __str__(self) -> str:
|
||||
ans_frames = ""
|
||||
ans_utterances = ""
|
||||
for k, v in self.norm_items():
|
||||
norm_value = "%.4g" % v
|
||||
if "utt_" not in k:
|
||||
ans_frames += str(k) + "=" + str(norm_value) + ", "
|
||||
else:
|
||||
ans_utterances += str(k) + "=" + str(norm_value)
|
||||
if k == "utt_duration":
|
||||
ans_utterances += " frames, "
|
||||
elif k == "utt_pad_proportion":
|
||||
ans_utterances += ", "
|
||||
else:
|
||||
raise ValueError(f"Unexpected key: {k}")
|
||||
frames = "%.2f" % self["frames"]
|
||||
ans_frames += "over " + str(frames) + " frames. "
|
||||
if ans_utterances != "":
|
||||
utterances = "%.2f" % self["utterances"]
|
||||
ans_utterances += "over " + str(utterances) + " utterances."
|
||||
|
||||
return ans_frames + ans_utterances
|
||||
|
||||
def norm_items(self) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
Returns a list of pairs, like:
|
||||
[('ctc_loss', 0.1), ('att_loss', 0.07)]
|
||||
"""
|
||||
num_frames = self["frames"] if "frames" in self else 1
|
||||
num_utterances = self["utterances"] if "utterances" in self else 1
|
||||
ans = []
|
||||
for k, v in self.items():
|
||||
if k == "frames" or k == "utterances":
|
||||
continue
|
||||
norm_value = (
|
||||
float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
|
||||
)
|
||||
ans.append((k, norm_value))
|
||||
return ans
|
||||
|
||||
def reduce(self, device):
|
||||
"""
|
||||
Reduce using torch.distributed, which I believe ensures that
|
||||
all processes get the total.
|
||||
"""
|
||||
keys = sorted(self.keys())
|
||||
s = torch.tensor([float(self[k]) for k in keys], device=device)
|
||||
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||
for k, v in zip(keys, s.cpu().tolist()):
|
||||
self[k] = v
|
||||
|
||||
def write_summary(
|
||||
self,
|
||||
tb_writer: SummaryWriter,
|
||||
prefix: str,
|
||||
batch_idx: int,
|
||||
) -> None:
|
||||
"""Add logging information to a TensorBoard writer.
|
||||
|
||||
Args:
|
||||
tb_writer: a TensorBoard writer
|
||||
prefix: a prefix for the name of the loss, e.g. "train/valid_",
|
||||
or "train/current_"
|
||||
batch_idx: The current batch index, used as the x-axis of the plot.
|
||||
"""
|
||||
for k, v in self.norm_items():
|
||||
tb_writer.add_scalar(prefix + k, v, batch_idx)
|
||||
|
||||
|
||||
def store_transcripts(
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
File to save the results to.
|
||||
texts:
|
||||
An iterable of tuples. The first element is the cur_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
If it is a multi-talker ASR system, the ref and hyp may also be lists of
|
||||
strings.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf8") as f:
|
||||
for cut_id, ref, hyp in texts:
|
||||
if char_level:
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
print(f"{cut_id}:\tref={ref}", file=f)
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
compute_CER: bool = False,
|
||||
sclite_mode: bool = False,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts.
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted result.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`,
|
||||
but it is predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the cut_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
|
||||
if compute_CER:
|
||||
for i, res in enumerate(results):
|
||||
cut_id, ref, hyp = res
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
ref_len = sum([len(r) for _, r, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
434
egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
Normal file
434
egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
Normal file
@ -0,0 +1,434 @@
|
||||
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
|
||||
import io
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import gradio as gr
|
||||
import gradio.processing_utils as processing_utils
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import whisper
|
||||
#from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from gradio_client import utils as client_utils
|
||||
from model import SPEECH_LLM, EncoderProjector
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
|
||||
def get_model(params, device="cuda"):
|
||||
"""Load and prepare the speech-to-speech model."""
|
||||
if params.remove_whisper_encoder_input_length_restriction:
|
||||
replace_whisper_encoder_forward()
|
||||
|
||||
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
||||
speech_encoder = whisper_model.encoder
|
||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
else:
|
||||
attn_implementation = "eager"
|
||||
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
params.llm_path_or_name,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
if params.use_lora:
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=16,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"up_proj",
|
||||
"gate_proj",
|
||||
"down_proj",
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
llm = get_peft_model(llm, lora_config)
|
||||
llm.print_trainable_parameters()
|
||||
|
||||
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
||||
tokenizer.add_special_tokens(special_tokens_dict)
|
||||
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
||||
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
|
||||
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
|
||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
||||
DEFAULT_SPEECH_TOKEN
|
||||
)
|
||||
|
||||
encoder_projector = EncoderProjector(
|
||||
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
|
||||
)
|
||||
|
||||
# codec_vocab_size = 4096 + 4
|
||||
codec_vocab_size = 6561 + 4
|
||||
config = Qwen2Config(
|
||||
vocab_size=codec_vocab_size,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
intermediate_size=2048,
|
||||
max_position_embeddings=4096,
|
||||
)
|
||||
codec_lm = AutoModelForCausalLM.from_config(
|
||||
config=config,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||
codec_lm.vocab_size = codec_vocab_size
|
||||
codec_lm.config.pad_token_id = codec_vocab_size - 1
|
||||
codec_lm.config.eos_token_id = codec_vocab_size - 2
|
||||
codec_lm.config.bos_token_id = codec_vocab_size - 3
|
||||
codec_lm.config.mask_token_id = codec_vocab_size - 4
|
||||
|
||||
model = SPEECH_LLM(
|
||||
speech_encoder,
|
||||
llm,
|
||||
encoder_projector,
|
||||
codec_lm,
|
||||
codec_lm_padding_side="left" if params.use_flash_attn else "right",
|
||||
)
|
||||
|
||||
checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def audio_decode_cosyvoice(audio_tokens, codec_decoder):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
|
||||
Args:
|
||||
audio_tokens (list): List of audio tokens to be processed.
|
||||
codec_decoder: Codec decoder for generating audio.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Generated audio waveform.
|
||||
"""
|
||||
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
|
||||
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
|
||||
prompt_speech_feat = torch.zeros(1, 0, 80)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
|
||||
prompt_token_len=torch.tensor(
|
||||
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
|
||||
prompt_feat_len=torch.tensor(
|
||||
[prompt_speech_feat.shape[1]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
embedding=flow_embedding.to(codec_decoder.model.device),
|
||||
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer,
|
||||
):
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
chat_template=TEMPLATE,
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||
for text in texts
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||
for text in texts
|
||||
]
|
||||
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
def format_history(history: list):
|
||||
messages = []
|
||||
for item in history:
|
||||
if isinstance(item["content"], str):
|
||||
messages.append({"role": item["role"], "content": item["content"]})
|
||||
return messages
|
||||
|
||||
def decode(
|
||||
model,
|
||||
token2wav_model,
|
||||
tokenizer,
|
||||
feature,
|
||||
messages,
|
||||
):
|
||||
"""Decode one
|
||||
Returns:
|
||||
pass
|
||||
"""
|
||||
|
||||
dtype = torch.float32
|
||||
device = model.llm.device
|
||||
|
||||
feature = feature.to(device, dtype=dtype)
|
||||
|
||||
input_ids, attention_mask = preprocess([messages], tokenizer)
|
||||
|
||||
generated_ids, audio_tokens = model.decode_with_speech_output(
|
||||
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
)
|
||||
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
yield {"type": "text", "data": hyps[0]}
|
||||
|
||||
audio_tokens = [token for token in audio_tokens if token < 4096]
|
||||
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
||||
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
|
||||
audio = audio_hat.squeeze(0).cpu().numpy()
|
||||
audio = np.array(audio * 32767).astype(np.int16)
|
||||
wav_io = io.BytesIO()
|
||||
sf.write(wav_io, audio, samplerate=22050, format="WAV")
|
||||
wav_io.seek(0)
|
||||
wav_bytes = wav_io.getvalue()
|
||||
audio_path = processing_utils.save_bytes_to_cache(
|
||||
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
|
||||
)
|
||||
|
||||
yield {"type": "audio", "data": audio_path}
|
||||
|
||||
def media_predict(audio, history):
|
||||
# First yield
|
||||
yield (
|
||||
None, # microphone
|
||||
history, # media_chatbot
|
||||
gr.update(visible=False), # submit_btn
|
||||
gr.update(visible=True), # stop_btn
|
||||
)
|
||||
print(2333, history, audio)
|
||||
history.append({"role": "user", "content": (audio,)})
|
||||
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
formatted_history = format_history(
|
||||
history=history
|
||||
) # only keep string text format
|
||||
|
||||
assert audio is not None
|
||||
audio_transcript = get_transcript(
|
||||
audio,
|
||||
asr_model,
|
||||
)
|
||||
history[-2]["content"] = audio_transcript
|
||||
|
||||
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
|
||||
fbank = fbank.unsqueeze(0)
|
||||
assert fbank.ndim == 3
|
||||
|
||||
for chunk in decode(
|
||||
model, token2wav_model, tokenizer, fbank, formatted_history
|
||||
):
|
||||
if chunk["type"] == "text":
|
||||
history[-1]["content"] = chunk["data"]
|
||||
yield (
|
||||
None, # microphone
|
||||
history, # media_chatbot
|
||||
gr.update(visible=False), # submit_btn
|
||||
gr.update(visible=True), # stop_btn
|
||||
)
|
||||
if chunk["type"] == "audio":
|
||||
history.append(
|
||||
{"role": "assistant", "content": gr.Audio(chunk["data"])}
|
||||
)
|
||||
|
||||
# Final yield
|
||||
yield (
|
||||
None, # microphone
|
||||
history, # media_chatbot
|
||||
gr.update(visible=True), # submit_btn
|
||||
gr.update(visible=False), # stop_btn
|
||||
)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Tab("Online"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
microphone = gr.Audio(sources=["microphone"], type="filepath")
|
||||
submit_btn = gr.Button("Submit", variant="primary")
|
||||
stop_btn = gr.Button("Stop", visible=False)
|
||||
clear_btn = gr.Button("Clear History")
|
||||
with gr.Column(scale=2):
|
||||
media_chatbot = gr.Chatbot(height=650, type="messages")
|
||||
|
||||
def clear_history():
|
||||
return [], gr.update(value=None)
|
||||
|
||||
submit_event = submit_btn.click(
|
||||
fn=media_predict,
|
||||
inputs=[
|
||||
microphone,
|
||||
media_chatbot,
|
||||
],
|
||||
outputs=[microphone, media_chatbot, submit_btn, stop_btn],
|
||||
)
|
||||
stop_btn.click(
|
||||
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
cancels=[submit_event],
|
||||
queue=False,
|
||||
)
|
||||
clear_btn.click(
|
||||
fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
|
||||
)
|
||||
|
||||
demo.queue(default_concurrency_limit=100, max_size=100).launch(
|
||||
max_threads=100,
|
||||
ssr_mode=False,
|
||||
share=args.share,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
server_name=args.server_name,
|
||||
)
|
||||
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint name or path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Token2Wav path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--asr-model-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="ASR model dir, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash-attn2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable flash_attention_2 when loading the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Create a publicly shareable link for the interface.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inbrowser",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Automatically launch the interface in a new tab on the default browser.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port", type=int, default=8001, help="Demo server port."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def read_wave(wave_filename: str):
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and can be of type
|
||||
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples,
|
||||
which are normalized to the range [-1, 1].
|
||||
- Sample rate of the wave file.
|
||||
"""
|
||||
|
||||
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||
assert (
|
||||
samples.ndim == 1
|
||||
), f"Expected single channel, but got {samples.ndim} channels."
|
||||
|
||||
samples_float32 = samples.astype(np.float32)
|
||||
|
||||
return samples_float32, sample_rate
|
||||
|
||||
|
||||
def get_transcript(audio_path, recognizer):
|
||||
samples, sample_rate = read_wave(audio_path)
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
recognizer.decode_streams([s])
|
||||
return s.result.text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _get_args()
|
||||
model, tokenizer = get_model(args)
|
||||
token2wav = CosyVoice(
|
||||
args.token2wav_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
|
||||
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
|
||||
tokens=f"{args.asr_model_dir}/tokens.txt",
|
||||
num_threads=2,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
decoding_method="greedy_search",
|
||||
debug=False,
|
||||
)
|
||||
|
||||
_launch_demo(args, model, tokenizer, token2wav, asr_model)
|
@ -0,0 +1 @@
|
||||
../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
|
Loading…
x
Reference in New Issue
Block a user