clean code

This commit is contained in:
root 2025-03-03 05:40:38 +00:00
parent d2b473ad99
commit 7623939fbf
12 changed files with 207 additions and 972 deletions

94
egs/emilia/TTS/README.md Normal file
View File

@ -0,0 +1,94 @@
# Results
| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment |
|---------------------------------------|----------|-----------|--------|
| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)|
| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)|
| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.89% (16 steps) | |
# Introduction
[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) contains over 101k hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation.
See https://arxiv.org/pdf/2407.05361 for details.
> [!CAUTION]
> The next-gen Kaldi framework provides tools and models for generating high-quality synthetic speech (Text-to-Speech, TTS).
> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities.
>
> By using this framework, you agree to the following:
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
>
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
>
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
>
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
# Llasa (cosyvoice2 token)
The ./llasa_cosyvoice2_token directory contains the code for training Qwen2.5-0.5B models to predict cosyvoice2 semantic tokens.
Generated samples and training logs of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) 50k hours Chinese data can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main).
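At a high level, `train.py` extends the Qwen2.5 tokenizer with 6561 cosyvoice2 semantic tokens plus a `<|SPEECH_GENERATION_START|>` placeholder, and the data collator splices the speech-token ids (offset by the original vocabulary size) into the tokenized text prompt. Below is a minimal sketch of that token handling only; it ignores the chat template used in `train.py` and uses a made-up code sequence:
```
# Sketch only: illustrates the vocabulary extension and token offset used in train.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
original_vocab_size = len(tokenizer)

cosyvoice2_token_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + ["<|SPEECH_GENERATION_START|>"]
tokenizer.add_tokens(new_tokens)

speech_codes = [12, 345, 6560]  # hypothetical cosyvoice2 codes for one utterance
input_ids = tokenizer("Hello<|SPEECH_GENERATION_START|>")["input_ids"]
start_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
idx = input_ids.index(start_id)
# Replace the placeholder with the speech codes shifted past the original text vocabulary.
input_ids = input_ids[:idx] + [c + original_vocab_size for c in speech_codes] + input_ids[idx + 1:]
print(input_ids)
```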
Preparation:
```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 3 --stop_stage 4
# Or you could use the prepared tokens.
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
```
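The extracted tokens are jsonl files (the same paths referenced by `data_path` in `config.json`), and `train.py` reads them via the `datasets` library. One way to peek at a prepared shard, assuming the download command above and without assuming any field names:
```
from datasets import load_dataset

ds = load_dataset(
    "json",
    data_files="emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl",
    split="train",
)
print(len(ds))       # number of utterances in this shard
print(ds[0].keys())  # inspect the available fields
```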
The training command is given below:
```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install -r llasa_cosyvoice2_token/requirements.txt
# pip install -r icefall/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt
WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
```
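Training hyper-parameters and data paths come from `config.json` in `llasa_cosyvoice2_token`. An optional sanity check before launching `torchrun` might look like the sketch below; the field names follow the `config.json` in this recipe, and it assumes you run it from `egs/emilia/TTS`:
```
import json
import os

with open("llasa_cosyvoice2_token/config.json") as f:
    cfg = json.load(f)

# The LLM path and every token jsonl listed in data_path should exist locally.
assert os.path.isdir(cfg["llm_model_name_or_path"]), cfg["llm_model_name_or_path"]
for path in cfg["data_path"]:
    assert os.path.isfile(path), f"missing token file: {path}"
print("output_dir:", cfg["output_dir"], "epochs:", cfg["num_train_epochs"])
```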
To run inference with the Chinese Llasa_cosyvoice2_token model trained on Emilia with icefall, we need the cosyvoice2-token flow-matching [model](https://github.com/k2-fsa/icefall/pull/1880):
```
cd icefall/egs/wenetspeech4tts/TTS
huggingface-cli login
huggingface-cli download --local-dir llasa_cosyvoice2_token_qwen_0.5b yuekai/llasa_cosyvoice2_token_qwen_0.5b
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
split=test_zh
llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output-dir $output_dir \
--batch-size 1 \
--num-workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
# compute cer
huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
bash local/compute_wer.sh $output_dir $manifest
```
# Credits
- [Llasa](https://arxiv.org/abs/2502.04128)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)

View File

@ -1,6 +1,6 @@
{
"llm_model_name_or_path": "/workspace/slam/icefall_omni/egs/speech_llm/SPEECH2SPEECH/models/Qwen2.5-0.5B-Instruct",
"data_path": ["../emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
"llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct",
"data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
"bf16": false,
"output_dir": "./exp_zh",
"num_train_epochs": 3,

View File

@ -5,3 +5,4 @@ datasets
accelerate>=0.26.0
deepspeed
flash-attn
s3tokenizer

View File

@ -1,3 +1,11 @@
# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py
""" Example Usage
WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
"""
import json
import os
import random
@ -11,8 +19,7 @@ import torch
import torch.nn as nn
import transformers
import wandb
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@ -65,7 +72,7 @@ class CustomTrainingArguments(TrainingArguments):
remove_unused_columns: bool = field(default=False)
def data_collator(batch, tokenizer):
def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048):
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
"<|SPEECH_GENERATION_START|>"
)
@ -84,11 +91,11 @@ def data_collator(batch, tokenizer):
chat_template=TEMPLATE,
)
code = [c + 151665 for c in code]
code = [c + original_tokenizer_vocab_size for c in code]
idx = input_ids.index(speech_generation_start_index)
input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
if len(input_ids) < 2048:
if len(input_ids) < cut_off_len:
input_ids_list.append(input_ids)
max_len = max([len(input_ids) for input_ids in input_ids_list])
@ -140,7 +147,11 @@ def main():
)
tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
new_tokens = [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"]
original_tokenizer_vocab_size = len(tokenizer)
cosyvoice2_token_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
"<|SPEECH_GENERATION_START|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))
@ -157,7 +168,9 @@ def main():
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=lambda features: data_collator(features, tokenizer),
data_collator=lambda features: data_collator(
features, tokenizer, original_tokenizer_vocab_size
),
)
if is_main_process:

View File

@ -1,4 +1,5 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,21 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
cpu:
s3tokenizer --data_dir xxx.scp \
--device "cpu" \
--output_dir "./" \
--batch_size 32
gpu:
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
`which s3tokenizer` --data_dir xxx.scp \
local/extract_cosyvoice2_token.py --data_dir $data_dir \
--jsonl_file $jsonl_file_basename \
--device "cuda" \
--output_dir "./" \
--batch_size 32
--output_dir $output_dir \
--batch_size 32 \
--num_workers 2 \
--model "speech_tokenizer_v2_25hz"
"""

View File

@ -4,16 +4,17 @@ set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
# pip install lhotse s3tokenizer
stage=6
stop_stage=6
stage=3
stop_stage=4
# Please download the OpenDataLab format from HuggingFace; you can set the revision argument to fc71e07e8572f5f3be1dbd02ed3172a4d298f152 to get the old format.
# https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07e8572f5f3be1dbd02ed3172a4d298f152
dl_dir=$PWD/download
dl_dir=/workspace_data/Emilia-Dataset/
prefix="emilia"
# zh, en, ja, ko, de, fr
lang_set=("de" "en" "zh" "ja" "ko" "fr")
lang_set=("de" "en" "zh" "ja" "fr")
. shared/parse_options.sh || exit 1
@ -29,23 +30,20 @@ log() {
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "dl_dir: $dl_dir"
log "Stage 0: Download data"
#huggingface-cli login
# huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS
# Extract the downloaded data:
cat $dl_dir/raw/EN/EN_B00008.tar.gz.* > $dl_dir/raw/EN/EN_B00008.tar.gz
for lang in "${lang_set[@]}"; do
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
folder=$dl_dir/raw/${lang_upper}
for file in $folder/*.tar.gz; do
echo "Processing ${file}"
# e.g. $dl_dir/raw/DE/*tar.gz untar first, DE is the language code in upper case
tar -xzvf $file -C $folder
done
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare emilia manifest"
log "Stage 1: Prepare emilia manifest (used by ./f5-tts)"
# We assume that you have downloaded the Emilia corpus
# to $dl_dir/emilia
mkdir -p data/manifests
@ -58,7 +56,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Generate fbank (used by ./f5-tts)"
mkdir -p data/fbank
@ -71,67 +68,8 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
done
fi
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
echo "Combining ${prefix} cuts"
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
fi
if [ ! -e data/fbank/.${prefix}_split.done ]; then
echo "Splitting ${prefix} cuts into train, valid and test sets"
lhotse subset --last 800 \
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
data/fbank/${prefix}_cuts_validtest.jsonl.gz
lhotse subset --first 400 \
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
data/fbank/${prefix}_cuts_valid.jsonl.gz
lhotse subset --last 400 \
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
data/fbank/${prefix}_cuts_test.jsonl.gz
rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
lhotse subset --first $n \
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
data/fbank/${prefix}_cuts_train.jsonl.gz
touch data/fbank/.${prefix}_split.done
fi
fi
# zcat test.jsonl.gz | jq -r '.recording.id + " " + .recording.sources[0].source' > wav.scp
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)"
data_dir=$dl_dir/raw/ZH
# for all jsonl files in data_dir
for jsonl_file in $data_dir/*.jsonl; do
# get the file basename
jsonl_file_basename=$(basename $jsonl_file)
echo "Processing $jsonl_file"
output_dir="./cosy_v2_tokens_ZH/${jsonl_file_basename%.jsonl}"
echo "output_dir: $output_dir"
# skip if the output_dir exists
if [ -e $output_dir ]; then
echo "Output directory $output_dir already exists, skipping"
continue
fi
mkdir -p $output_dir
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
local/extract_cosyvoice2_token.py --data_dir $data_dir \
--jsonl_file $jsonl_file_basename \
--device "cuda" \
--output_dir $output_dir \
--batch_size 32 \
--num_workers 2 \
--model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)"
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Extract cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)"
for lang in "${lang_set[@]}"; do
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
data_dir=$dl_dir/raw/${lang_upper}
@ -161,14 +99,13 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
# cat EN_B00008.tar.gz.* > EN_B00008.tar.gz
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Merge cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)"
for lang in "${lang_set[@]}"; do
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
cosy_token_dir="./cosy_v2_tokens_${lang_upper}"
for dir in $cosy_token_dir/*; do
echo "Processing $dir"
# get the file basename
dir_basename=$(basename $dir)
echo "dir_basename: $dir_basename"
cat $dir/part* > $dir/${dir_basename}.jsonl

View File

@ -186,3 +186,5 @@ bash local/compute_wer.sh $output_dir $manifest
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)
- [Spark-TTS](https://github.com/SparkAudio/Spark-TTS)

View File

@ -108,13 +108,6 @@ def get_parser():
help="Interpolate semantic token to match mel frames for CosyVoice",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
parser.add_argument(
"--split-name",
type=str,

View File

@ -1,4 +1,5 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,23 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
cpu:
s3tokenizer --data_dir xxx.scp \
--device "cpu" \
--output_dir "./" \
--batch_size 32
gpu:
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
`which s3tokenizer` --data_dir xxx.scp \
--device "cuda" \
--output_dir "./" \
--batch_size 32
split=test_zh
llm_path=f5-tts/exp_zh/checkpoint-805000
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output-dir $output_dir \
--batch-size 1 \
--num-workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
"""
import argparse
@ -81,16 +85,16 @@ def get_args():
help="huggingface dataset split name",
)
parser.add_argument(
"--output_dir", required=True, type=str, help="dir to save result"
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch_size",
"--batch-size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num_workers", type=int, default=4, help="workers for dataloader"
"--num-workers", type=int, default=4, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
@ -119,6 +123,24 @@ def get_args():
type=str,
help="flow matching model path",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
add_model_arguments(parser)
args = parser.parse_args()
return args
@ -285,7 +307,11 @@ def main():
for batch in dataloader:
generate_codes = model.inference_batch(
batch["input_ids"], batch["attention_mask"]
batch["input_ids"],
batch["attention_mask"],
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
)
flow_matching_input_tokens, total_mel_lens = [], []
for i, code in enumerate(generate_codes):

View File

@ -1,828 +0,0 @@
#!/usr/bin/env python3
# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
"""
Usage:
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
# command for text token input
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
# command for cosyvoice semantic token input
split=test_zh # seed_tts_eval test_zh
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
import logging
import math
import os
import random
import time
from pathlib import Path
import datasets
import torch
import torch.nn.functional as F
import torchaudio
from accelerate import Accelerator
from bigvganinference import BigVGANInference
from model.cfm import CFM
from model.dit import DiT
from model.modules import MelSpec
from model.utils import convert_char_to_pinyin
from tqdm import tqdm
from train import (
add_model_arguments,
get_model,
get_tokenizer,
interpolate_tokens,
load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)
parser.add_argument(
"--model-path",
type=str,
default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
help="Path to the unique text tokens file",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--nfe",
type=int,
default=16,
help="The number of steps for the neural ODE",
)
parser.add_argument(
"--manifest-file",
type=str,
default=None,
help="The manifest file in seed_tts_eval format",
)
parser.add_argument(
"--output-dir",
type=Path,
default="results",
help="The output directory to save the generated wavs",
)
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument(
"--interpolate-token",
type=str2bool,
default=True,
help="Interpolate semantic token to match mel frames for CosyVoice",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
add_model_arguments(parser)
return parser.parse_args()
def get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
):
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
metainfo, desc="Processing prompts..."
):
# Audio
ref_audio, ref_sr = torchaudio.load(prompt_wav)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
if ref_rms < target_rms:
ref_audio = ref_audio * target_rms / ref_rms
assert (
ref_audio.shape[-1] > 5000
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio)
# Text
if len(prompt_text[-1].encode("utf-8")) == 1:
prompt_text = prompt_text + " "
text = [prompt_text + gt_text]
if tokenizer == "pinyin":
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
else:
text_list = text
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(
ref_mel_len / ref_text_len * gen_text_len / speed
)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
final_text_list[bucket_i].extend(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# not only leave easy work for last workers
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def get_inference_prompt_cosy_voice_huggingface(
dataset,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
interpolate_token=False,
):
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for i in range(len(dataset)):
utt = dataset[i]["id"]
ref_audio_org, ref_sr = (
dataset[i]["prompt_audio"]["array"],
dataset[i]["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
audio_tokens = dataset[i]["target_audio_cosy2_tokens"]
prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"]
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
input_tokens = prompt_audio_tokens + audio_tokens
if interpolate_token:
input_tokens = interpolate_tokens(input_tokens)
text_list = input_tokens
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
total_mel_len = len(input_tokens)
if not interpolate_token:
total_mel_len = int(total_mel_len / 4 * 15)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
if total_mel_len > max_tokens:
print(
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
continue
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
# final_text_list[bucket_i].extend(text_list)
final_text_list[bucket_i].append(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# not only leave easy work for last workers
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def inference_speech_token(
cosyvoice,
tts_text,
prompt_text,
prompt_speech_16k,
stream=False,
speed=1.0,
text_frontend=True,
):
tokens = []
prompt_text = cosyvoice.frontend.text_normalize(
prompt_text, split=False, text_frontend=text_frontend
)
for i in cosyvoice.frontend.text_normalize(
tts_text, split=True, text_frontend=text_frontend
):
tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
(
prompt_text_token,
prompt_text_token_len,
) = cosyvoice.frontend._extract_text_token(prompt_text)
speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
prompt_speech_16k
)
for i in cosyvoice.model.llm.inference(
text=tts_text_token.to(cosyvoice.model.device),
text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
cosyvoice.model.device
),
prompt_text=prompt_text_token.to(cosyvoice.model.device),
prompt_text_len=torch.tensor(
[prompt_text_token.shape[1]], dtype=torch.int32
).to(cosyvoice.model.device),
prompt_speech_token=speech_token.to(cosyvoice.model.device),
prompt_speech_token_len=torch.tensor(
[speech_token.shape[1]], dtype=torch.int32
).to(cosyvoice.model.device),
embedding=None,
):
tokens.append(i)
return tokens, speech_token
def get_inference_prompt_cosy_voice(
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
interpolate_token=False,
):
import sys
# please change the path to the cosyvoice accordingly
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice")
from cosyvoice.cli.cosyvoice import CosyVoice2
# please download the cosyvoice model first
cosyvoice = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
metainfo, desc="Processing prompts..."
):
# Audio
ref_audio_org, ref_sr = torchaudio.load(prompt_wav)
# cosy voice
if ref_sr != 16000:
resampler = torchaudio.transforms.Resample(ref_sr, 16000)
ref_audio_16k = resampler(ref_audio_org)
else:
ref_audio_16k = ref_audio_org
audio_tokens, prompt_audio_tokens = inference_speech_token(
cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
assert (
ref_audio_org.shape[-1] > 5000
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
# Text
# if len(prompt_text[-1].encode("utf-8")) == 1:
# prompt_text = prompt_text + " "
# text = [prompt_text + gt_text]
# if tokenizer == "pinyin":
# text_list = convert_char_to_pinyin(text, polyphone=polyphone)
# else:
# text_list = text
# concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens
# prompt_audio_tokens shape 1, prompt_audio_tokens
# audio_tokens shape 1, audio_tokens
prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist()
input_tokens = prompt_audio_tokens + audio_tokens
# convert it into a list
# input_tokens_list = input_tokens.squeeze().cpu().tolist()
if interpolate_token:
input_tokens = interpolate_tokens(input_tokens)
text_list = input_tokens
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len_compute = ref_mel_len + int(
ref_mel_len / ref_text_len * gen_text_len / speed
)
total_mel_len = len(input_tokens)
if not interpolate_token:
total_mel_len = int(total_mel_len / 4 * 15)
print(
f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
# final_text_list[bucket_i].extend(text_list)
final_text_list[bucket_i].append(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# not only leave easy work for last workers
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
return padded_ref_mels
def get_seedtts_testset_metainfo(metalst):
f = open(metalst)
lines = f.readlines()
f.close()
metainfo = []
for line in lines:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
return metainfo
def main():
args = get_parser()
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
if args.manifest_file:
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
if not args.use_cosyvoice_semantic_token:
prompts_all = get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
)
else:
prompts_all = get_inference_prompt_cosy_voice(
metainfo,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
interpolate_token=args.interpolate_token,
)
else:
assert args.use_cosyvoice_semantic_token
dataset = datasets.load_dataset(
"yuekai/seed_tts_cosy2",
split=args.split_name,
trust_remote_code=True,
)
prompts_all = get_inference_prompt_cosy_voice_huggingface(
dataset,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
interpolate_token=args.interpolate_token,
)
vocoder = BigVGANInference.from_pretrained(
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
)
vocoder = vocoder.eval().to(device)
model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(
args.model_path,
model=model,
)
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
(
utts,
ref_rms_list,
ref_mels,
ref_mel_lens,
total_mel_lens,
final_text_list,
) = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
if args.use_cosyvoice_semantic_token:
# concat final_text_list
max_len = max([len(tokens) for tokens in final_text_list])
# pad tokens to the same length
for i, tokens in enumerate(final_text_list):
final_text_list[i] = torch.tensor(
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
)
final_text_list = torch.stack(final_text_list).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=args.nfe,
cfg_strength=2.0,
sway_sampling_coef=args.swaysampling,
no_ref_audio=False,
seed=args.seed,
)
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(
f"{args.output_dir}/{utts[i]}.wav",
generated_wave,
target_sample_rate,
)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1,5 +1,6 @@
# Copyright (c) 2025 SparkAudio
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
# 2025 Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py
# Modified from https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py
import re
from pathlib import Path
@ -39,7 +40,9 @@ class LLMTTS:
Args:
model_dir (Path): Directory containing the model and config files.
device (torch.device): The device (CPU/GPU) to run the model on.
tokenizer_dir (Path): Directory containing the tokenizer files.
s3_tokenizer_name (str): Name of the S3Tokenizer speech-tokenizer model (e.g. speech_tokenizer_v2_25hz).
device (torch.device): Device to run the model on.
"""
self.device = device
@ -51,7 +54,9 @@ class LLMTTS:
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [
self.original_vocab_size = len(tokenizer)
self.cosyvoice2_token_vocab_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(self.cosyvoice2_token_vocab_size)] + [
"<|SPEECH_GENERATION_START|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
@ -67,42 +72,39 @@ class LLMTTS:
temperature: float = 0.8,
top_k: float = 50,
top_p: float = 0.95,
max_new_tokens: int = 1024,
) -> torch.Tensor:
"""
Performs batched inference to generate cosyvoice2 speech tokens from the tokenized prompts.
Args:
text (str): The text input to be converted to speech.
prompt_speech_path (Path): Path to the audio file used as a prompt.
prompt_text (str, optional): Transcript of the prompt audio.
gender (str): female | male.
pitch (str): very_low | low | moderate | high | very_high
speed (str): very_low | low | moderate | high | very_high
input_ids (torch.Tensor): Input IDs for the model.
attention_mask (torch.Tensor): Attention mask for the model.
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
top_k (float, optional): Top-k sampling parameter. Default is 50.
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
max_new_tokens (int, optional): Maximum number of tokens to generate. Default is 1024.
Returns:
list[list[int]]: Generated cosyvoice2 speech token ids, one list per batch item.
"""
# Generate speech using the model
generated_ids = self.model.generate(
input_ids=input_ids.to(self.device),
attention_mask=attention_mask.to(self.device),
max_new_tokens=1024,
max_new_tokens=max_new_tokens,
do_sample=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
)
results = []
generated_ids = generated_ids.cpu().tolist()
for i in range(len(generated_ids)):
assistant_index = generated_ids[i].index(self.assistant_index)
padding_index = len(generated_ids[i])
# WAR: hard-coded assistant_index + 2 for the current chat template ("Assistant:\n")
result = generated_ids[i][assistant_index + 2 :]
result = [token - 151665 for token in result]
result = [token - self.original_vocab_size for token in result]
result = [token for token in result if token >= 0]
results.append(result)
return results

View File

@ -118,6 +118,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of Decoder layers.",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -313,13 +320,6 @@ def get_parser():
help="perform OOM check on dataloader batches before starting training.",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
add_model_arguments(parser)
return parser