Merge 1653b76deb5bffec958d17cf5440ace4f776732f into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Yuekai Zhang 2025-07-19 14:37:36 +08:00 committed by GitHub
commit 5c73a91e72
13 changed files with 1150 additions and 28 deletions

76
egs/emilia/TTS/README.md Normal file

@ -0,0 +1,76 @@
# Results
| LLM Model | Flow Matching Model | Seed-TTS test_zh CER | Comment |
|-----------|---------------------|----------------------|---------|
| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)|
| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)|
| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.81% (16 steps) | |
# Introduction
[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) contains over 101k hours of speech in six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation.
See https://arxiv.org/pdf/2407.05361 for details.
# Llasa (cosyvoice2 token)
./llasa_cosyvoice2_token contains the code for training a Qwen2.5-0.5B model to predict CosyVoice2 semantic tokens.
Generated samples and training logs of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) 50k hours Chinese data can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main).
Preparation:
```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 3 --stop_stage 4
# Or you could use the prepared tokens.
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
```
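Each line in the prepared token files is a JSON object; `train.py` only relies on the `text` and `code` fields, where `code` holds the CosyVoice2 FSQ token ids. The sample line below is illustrative only (the `wav` value and any extra fields depend on the Emilia manifest):
```
{"wav": "ZH_B00000/ZH_B00000_S00000_W000000.mp3", "text": "这是一个示例文本。", "code": [2921, 88, 1517, 4502]}
```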
The training command is given below:
```
# docker: ghcr.io/swivid/f5-tts:main
# pip install -r llasa_cosyvoice2_token/requirements.txt
WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
```
To run inference with the Emilia-trained Chinese llasa_cosyvoice2_token model, pair it with the CosyVoice2-token flow matching [model](https://github.com/k2-fsa/icefall/pull/1880):
```
cd icefall/egs/wenetspeech4tts/TTS
huggingface-cli login
exp_dir=./llasa_cosyvoice2_token_qwen_0.5b
huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
split=test_zh
llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
output_dir=./seed_tts_${split}_output  # any directory for the generated wavs
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output_dir $output_dir \
--batch_size 1 \
--num_workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
# compute cer
huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
bash local/compute_wer.sh $output_dir $manifest
```
# Credits
- [Llasa](https://arxiv.org/abs/2502.04128)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)

27
egs/emilia/TTS/llasa_cosyvoice2_token/config.json Normal file

@ -0,0 +1,27 @@
{
"llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct",
"data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
"bf16": false,
"output_dir": "./exp_zh",
"num_train_epochs": 3,
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"gradient_accumulation_steps": 1,
"evaluation_strategy": "steps",
"eval_steps": 1000,
"save_strategy": "steps",
"save_steps": 5000,
"save_total_limit": 100,
"learning_rate": 0.00005,
"weight_decay": 0.01,
"adam_beta2": 0.95,
"warmup_ratio": 0.03,
"lr_scheduler_type": "cosine",
"logging_steps": 100,
"report_to": "wandb",
"model_max_length": 2048,
"gradient_checkpointing": false,
"dataloader_num_workers": 4,
"dataloader_prefetch_factor": 4,
"deepspeed": "ds_config_zero2.json"
}

47
egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json Normal file

@ -0,0 +1,47 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 64,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupCosineLR",
"params": {
"total_num_steps": "auto",
"warmup_min_ratio": 0.03,
"warmup_num_steps": "auto",
"cos_min_ratio": 0.1
}
},
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

8
egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt Normal file

@ -0,0 +1,8 @@
torch
transformers
wandb
datasets
accelerate>=0.26.0
deepspeed
flash-attn
s3tokenizer

184
egs/emilia/TTS/llasa_cosyvoice2_token/train.py Normal file

@ -0,0 +1,184 @@
# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py
""" Example Usage
WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
"""
import json
import os
import random
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import transformers
import wandb
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Trainer,
TrainingArguments,
)
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
@dataclass
class ModelArguments:
llm_model_name_or_path: Optional[str] = field(
default="meta-llama/Llama-3.2-1B-Instruct"
)
@dataclass
class DataArguments:
data_path: List[str] = field(
default=None,
metadata={"help": "Root path(s) to the data. Can be single path or list."},
)
@dataclass
class CustomTrainingArguments(TrainingArguments):
optim: str = field(default="adamw_torch_fused")
model_max_length: int = field(
default=2048,
metadata={"help": "Maximum sequence length"},
)
logging_steps: int = field(default=100, metadata={"help": "Log every X updates"})
report_to: Optional[str] = field(
default=None,
metadata={"help": "The integration to report the results and logs to."},
)
run_name: Optional[str] = field(
default=None, metadata={"help": "The name of the run for logging."}
)
gradient_checkpointing: bool = field(default=False)
lr_scheduler_type: str = field(
default="cosine", metadata={"help": "The learning rate scheduler to use."}
)
remove_unused_columns: bool = field(default=False)
def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048):
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
"<|SPEECH_GENERATION_START|>"
)
assistant_index = tokenizer.convert_tokens_to_ids("assistant")
input_ids_list = []
for i, item in enumerate(batch):
text, code = item["text"], item["code"]
message = [
{"role": "user", "content": f"Convert the text to speech: {text}"},
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
]
input_ids = tokenizer.apply_chat_template(
message,
tokenize=True,
chat_template=TEMPLATE,
)
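        # Shift each CosyVoice2 code by the original vocab size so it indexes the
        # newly added <|s_i|> tokens, then splice the codes in place of the single
        # <|SPEECH_GENERATION_START|> placeholder in the assistant turn.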
code = [c + original_tokenizer_vocab_size for c in code]
idx = input_ids.index(speech_generation_start_index)
input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
if len(input_ids) < cut_off_len:
input_ids_list.append(input_ids)
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
for input_ids in input_ids_list
]
input_ids = torch.tensor(input_ids_list, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
mask_indices = torch.where(input_ids == assistant_index)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": target_ids.to(dtype=torch.int64),
}
def main():
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, CustomTrainingArguments)
)
assert len(sys.argv) == 2 and sys.argv[1].endswith(".json")
(
model_args,
data_args,
training_args,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
is_main_process = training_args.local_rank in [-1, 0]
if training_args.report_to == "wandb" and is_main_process:
wandb.init(
project="llm_tts",
config=training_args.to_sanitized_dict(),
name=training_args.run_name,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.llm_model_name_or_path,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
original_tokenizer_vocab_size = len(tokenizer)
cosyvoice2_token_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
"<|SPEECH_GENERATION_START|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
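    # Grow the embedding table (and tied LM head) so the 6561 CosyVoice2 codes and
    # the <|SPEECH_GENERATION_START|> placeholder get trainable rows.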
model.resize_token_embeddings(len(tokenizer))
model.vocab_size = len(tokenizer)
dataset = load_dataset("json", data_files=data_args.data_path)
dataset = dataset["train"]
train_test_split = dataset.train_test_split(test_size=100, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=lambda features: data_collator(
features, tokenizer, original_tokenizer_vocab_size
),
)
if is_main_process:
trainer.add_callback(transformers.integrations.WandbCallback())
trainer.train(resume_from_checkpoint=None)
trainer.save_model(training_args.output_dir)
if __name__ == "__main__":
main()

200
egs/emilia/TTS/local/extract_cosyvoice2_token.py Normal file

@ -0,0 +1,200 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
local/extract_cosyvoice2_token.py --data_dir $data_dir \
--jsonl_file $jsonl_file_basename \
--device "cuda" \
--output_dir $output_dir \
--batch_size 32 \
--num_workers 2 \
--model "speech_tokenizer_v2_25hz"
"""
import argparse
import json
import os
from pathlib import Path
import s3tokenizer
import torch
import torch.distributed as dist
from lhotse.serialization import load_jsonl
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
class AudioDataset(Dataset):
def __init__(self, data_dir, jsonl_file):
self.data = []
# convert data_dir to Path object
self.data_dir = Path(data_dir)
# jsonl_files = self.data_dir.glob("*.jsonl")
jsonl_files = [self.data_dir / jsonl_file]
for jsonl_file in jsonl_files:
for item in tqdm(
# Note: People's Speech manifest.json is really a JSONL.
load_jsonl(jsonl_file),
desc=f"Processing {jsonl_file}",
):
self.data.append(item)
break
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
file_path = self.data_dir / self.data[idx]["wav"]
audio = s3tokenizer.load_audio(file_path)
if audio.shape[0] / 16000 > 30:
print(
f"do not support extract speech token for audio longer than 30s, file_path: {file_path}" # noqa
)
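            # Return an empty mel (0 frames) so this utterance simply yields a
            # zero-length token sequence instead of breaking the batch.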
mel = torch.zeros(128, 0)
else:
mel = s3tokenizer.log_mel_spectrogram(audio)
return self.data[idx], mel
def collate_fn(batch):
keys = [item[0] for item in batch]
mels = [item[1] for item in batch]
mels, mels_lens = s3tokenizer.padding(mels)
return keys, mels, mels_lens
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--model",
required=True,
type=str,
choices=[
"speech_tokenizer_v1",
"speech_tokenizer_v1_25hz",
"speech_tokenizer_v2_25hz",
],
help="model version",
)
parser.add_argument(
"--data_dir",
required=True,
type=str,
help="each line contains `wav_name wav_path`",
)
parser.add_argument(
"--jsonl_file",
required=True,
type=str,
help="each line contains `wav_name wav_path`",
)
parser.add_argument(
"--device",
required=True,
type=str,
choices=["cuda", "cpu"],
help="device for inference",
)
parser.add_argument(
"--output_dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch_size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num_workers", type=int, default=4, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
)
args = parser.parse_args()
return args
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
if args.device == "cuda":
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
else:
world_size, local_rank, rank = 1, 0, 0
device = torch.device(args.device)
model = s3tokenizer.load_model(args.model).to(device)
dataset = AudioDataset(args.data_dir, args.jsonl_file)
if args.device == "cuda":
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank]
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
else:
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=collate_fn,
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
writer = open(f"{args.output_dir}/part_{rank + 1}_of_{world_size}", "w")
for keys, mels, mels_lens in dataloader:
codes, codes_lens = model(mels.to(device), mels_lens.to(device))
for i, k in enumerate(keys):
code = codes[i, : codes_lens[i].item()].tolist()
k["code"] = code
writer.write(json.dumps(k, ensure_ascii=False) + "\n")
if rank == 0:
progress_bar.update(world_size * len(keys))
if rank == 0:
progress_bar.close()
writer.close()
if args.device == "cuda":
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

115
egs/emilia/TTS/prepare.sh Executable file

@ -0,0 +1,115 @@
#!/usr/bin/env bash
set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
stage=3
stop_stage=4
# Please download the OpenDataLab format from HuggingFace. You can pin the revision to fc71e07e8572f5f3be1dbd02ed3172a4d298f152, which is the old (OpenDataLab-style) format.
# https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07e8572f5f3be1dbd02ed3172a4d298f152
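# For example (paths are illustrative; adjust --local-dir to where stage 0 expects the raw data):
# huggingface-cli download --repo-type dataset --revision fc71e07e8572f5f3be1dbd02ed3172a4d298f152 \
#   amphion/Emilia-Dataset --local-dir $dl_dir/raw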
dl_dir=$PWD/download
prefix="emilia"
# zh, en, ja, ko, de, fr
lang_set=("de" "en" "zh" "ja" "ko" "fr")
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "dl_dir: $dl_dir"
log "Stage 0: Download data"
# Join split archive parts (e.g. EN_B00008), then extract the downloaded tarballs:
cat $dl_dir/raw/EN/EN_B00008.tar.gz.* > $dl_dir/raw/EN/EN_B00008.tar.gz
for lang in "${lang_set[@]}"; do
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
folder=$dl_dir/raw/${lang_upper}
for file in $folder/*.tar.gz; do
echo "Processing ${file}"
tar -xzvf $file -C $folder
done
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare emilia manifest (used by ./f5-tts)"
# We assume that you have downloaded the Emilia corpus
# to $dl_dir/emilia
mkdir -p data/manifests
for lang in "${lang_set[@]}"; do
echo "Processing ${lang}"
if [ ! -e data/manifests/.emilia.${lang}.done ]; then
lhotse prepare emilia $dl_dir data/manifests --num-jobs 30 --lang "${lang}"
touch data/manifests/.emilia.${lang}.done
fi
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Generate fbank (used by ./f5-tts)"
mkdir -p data/fbank
for lang in "${lang_set[@]}"; do
echo "Processing ${lang}"
if [ ! -e data/fbank/.emilia.${lang}.done ]; then
./local/compute_mel_feat.py --dataset-parts $lang --split 100 --prefix ${prefix}
touch data/fbank/.emilia.${lang}.done
fi
done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Extract cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)"
for lang in "${lang_set[@]}"; do
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
data_dir=$dl_dir/raw/${lang_upper}
# for all jsonl files in data_dir
for jsonl_file in $data_dir/*.jsonl; do
# get the file basename
jsonl_file_basename=$(basename $jsonl_file)
echo "Processing $jsonl_file"
output_dir="./cosy_v2_tokens_${lang_upper}/${jsonl_file_basename%.jsonl}"
echo "output_dir: $output_dir"
# skip if the output_dir exists
if [ -e $output_dir ]; then
echo "Output directory $output_dir already exists, skipping"
continue
fi
mkdir -p $output_dir
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
local/extract_cosyvoice2_token.py --data_dir $data_dir \
--jsonl_file $jsonl_file_basename \
--device "cuda" \
--output_dir $output_dir \
--batch_size 32 \
--num_workers 2 \
--model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz
done
done
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Merge cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)"
for lang in "${lang_set[@]}"; do
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
cosy_token_dir="./cosy_v2_tokens_${lang_upper}"
for dir in $cosy_token_dir/*; do
echo "Processing $dir"
dir_basename=$(basename $dir)
echo "dir_basename: $dir_basename"
cat $dir/part* > $dir/${dir_basename}.jsonl
done
cat $cosy_token_dir/${lang_upper}*/*.jsonl > $cosy_token_dir/cosy_v2_tokens_${lang_upper}.jsonl
done
fi

1
egs/emilia/TTS/shared Symbolic link

@ -0,0 +1 @@
../../../icefall/shared/

egs/wenetspeech4tts/TTS/README.md

@ -9,20 +9,6 @@
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
> [!CAUTION]
> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS).
> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities.
>
> By using this framework, you agree to the following:
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
>
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
>
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
>
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
# [VALL-E](https://arxiv.org/abs/2301.02111)
./valle contains the code for training VALL-E TTS model.
@ -186,3 +172,5 @@ bash local/compute_wer.sh $output_dir $manifest
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)
- [Spark-TTS](https://github.com/SparkAudio/Spark-TTS)


@ -108,13 +108,6 @@ def get_parser():
help="Interpolate semantic token to match mel frames for CosyVoice", help="Interpolate semantic token to match mel frames for CosyVoice",
) )
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
    parser.add_argument(
        "--split-name",
        type=str,

373
egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py Normal file

@ -0,0 +1,373 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
split=test_zh
llm_path=f5-tts/exp_zh/checkpoint-805000
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output_dir $output_dir \
--batch_size 1 \
--num_workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
"""
import argparse
import json
import os
from pathlib import Path
import s3tokenizer
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from bigvganinference import BigVGANInference
from datasets import load_dataset
from lhotse.serialization import load_jsonl
from llm_tts import LLMTTS
from model.modules import MelSpec
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import (
add_model_arguments,
get_model,
get_tokenizer,
interpolate_tokens,
load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--s3-tokenizer-name",
required=False,
type=str,
choices=[
"speech_tokenizer_v1",
"speech_tokenizer_v1_25hz",
"speech_tokenizer_v2_25hz",
],
help="model version",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=4, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="model version",
)
parser.add_argument(
"--tokenizer-dir",
required=True,
type=str,
help="tokenizer dir",
)
parser.add_argument(
"--vocoder-dir",
required=True,
type=str,
help="vocoder dir",
)
parser.add_argument(
"--flow-matching-model-path",
required=True,
type=str,
help="flow matching model path",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
add_model_arguments(parser)
args = parser.parse_args()
return args
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
return padded_ref_mels
def data_collator(batch, tokenizer, mel_spectrogram):
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
"<|SPEECH_GENERATION_START|>"
)
assistant_index = tokenizer.convert_tokens_to_ids("assistant")
target_sample_rate = 24000
hop_length = 256
target_rms = 0.1
input_ids_list, ref_mel_list, ref_mel_len_list = [], [], []
for i, item in enumerate(batch):
prompt_text, target_text, prompt_audio_codes = (
item["prompt_text"],
item["target_text"],
item["prompt_audio_cosy2_tokens"],
)
message = [
{
"role": "user",
"content": f"Convert the text to speech: {prompt_text + target_text}",
},
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
]
input_ids = tokenizer.apply_chat_template(
message,
tokenize=True,
chat_template=TEMPLATE,
)
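        # 151665 is assumed to be the original Qwen2.5-0.5B-Instruct tokenizer size;
        # the shift maps the prompt's CosyVoice2 codes onto the added <|s_i|> tokens,
        # matching the offset used during LLM training.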
prompt_audio_codes = [c + 151665 for c in prompt_audio_codes]
idx = input_ids.index(speech_generation_start_index)
input_ids = input_ids[:idx] + prompt_audio_codes
input_ids_list.append(input_ids)
# get flow matching model's prompt mel spectrogram
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
# Duration in mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
[tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids
for input_ids in input_ids_list
]
input_ids = torch.tensor(input_ids_list, dtype=torch.int64)
attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
ids = [item["id"] for item in batch]
ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"ids": ids,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
model = LLMTTS(
model_dir=args.llm_model_name_or_path,
tokenizer_dir=args.tokenizer_dir,
s3_tokenizer_name=args.s3_tokenizer_name,
device=device,
)
vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
vocoder = vocoder.eval().to(device)
flow_matching_model = get_model(args).eval().to(device)
_ = load_checkpoint(
args.flow_matching_model_path,
model=flow_matching_model,
)
dataset = load_dataset(
"yuekai/seed_tts_cosy2",
split=args.split_name,
trust_remote_code=True,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
mel_spectrogram = MelSpec(
n_fft=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24000,
mel_spec_type="bigvgan",
)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram),
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
for batch in dataloader:
generate_codes = model.inference_batch(
batch["input_ids"],
batch["attention_mask"],
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
)
flow_matching_input_tokens, total_mel_lens = [], []
for i, code in enumerate(generate_codes):
flow_matching_input_token = interpolate_tokens(code)
total_mel_len = len(flow_matching_input_token)
flow_matching_input_tokens.append(flow_matching_input_token)
total_mel_lens.append(total_mel_len)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
"ref_mel_len_batch"
].to(device)
max_len = max([len(tokens) for tokens in flow_matching_input_tokens])
# pad tokens to the same length
for i, tokens in enumerate(flow_matching_input_tokens):
flow_matching_input_tokens[i] = torch.tensor(
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
)
flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device)
generated, _ = flow_matching_model.sample(
cond=ref_mels,
text=flow_matching_input_tokens,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
no_ref_audio=False,
seed=0,
)
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",
generated_wave,
target_sample_rate,
)
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
if rank == 0:
progress_bar.close()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

110
egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py Normal file

@ -0,0 +1,110 @@
# Copyright (c) 2025 SparkAudio
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
# 2025 Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py
import re
from pathlib import Path
from typing import List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class LLMTTS:
"""
LLM-TTS for text-to-speech generation.
"""
def __init__(
self,
model_dir: Path,
tokenizer_dir: Path,
s3_tokenizer_name: str,
device: torch.device,
):
"""
Initializes the LLMTTS model with the provided configurations and device.
Args:
model_dir (Path): Directory containing the model and config files.
tokenizer_dir (Path): Directory containing the tokenizer files.
            s3_tokenizer_name (str): Name of the S3Tokenizer speech-tokenizer model (may be None).
device (torch.device): Device to run the model on.
"""
self.device = device
self.model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch.float16,
device_map=device,
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
self.original_vocab_size = len(tokenizer)
self.cosyvoice2_token_vocab_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(self.cosyvoice2_token_vocab_size)] + [
"<|SPEECH_GENERATION_START|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
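        # Left padding keeps every prompt flush against the generation start,
        # so batched generate() continues each sequence from its own prompt.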
tokenizer.padding_side = "left"
self.tokenizer = tokenizer
self.assistant_index = tokenizer.convert_tokens_to_ids("assistant")
@torch.no_grad()
def inference_batch(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 0.8,
top_k: float = 50,
top_p: float = 0.95,
max_new_tokens: int = 1024,
) -> torch.Tensor:
"""
Performs inference to generate speech from text, incorporating prompt audio and/or text.
Args:
input_ids (torch.Tensor): Input IDs for the model.
attention_mask (torch.Tensor): Attention mask for the model.
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
top_k (float, optional): Top-k sampling parameter. Default is 50.
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
max_new_tokens (int, optional): Maximum number of tokens to generate. Default is 1024.
Returns:
            List[List[int]]: Generated CosyVoice2 speech token ids for each item in the batch.
"""
generated_ids = self.model.generate(
input_ids=input_ids.to(self.device),
attention_mask=attention_mask.to(self.device),
max_new_tokens=max_new_tokens,
do_sample=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
)
results = []
generated_ids = generated_ids.cpu().tolist()
for i in range(len(generated_ids)):
assistant_index = generated_ids[i].index(self.assistant_index)
padding_index = len(generated_ids[i])
            # WAR: hard-coding assistant_index + 2 to skip the 'assistant' token and the following '\n' in the current template
result = generated_ids[i][assistant_index + 2 :]
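            # Map ids back to raw CosyVoice2 codes; anything that is not one of the
            # added speech tokens becomes negative after the shift and is dropped.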
result = [token - self.original_vocab_size for token in result]
result = [token for token in result if token >= 0]
results.append(result)
return results

egs/wenetspeech4tts/TTS/f5-tts/train.py

@ -118,6 +118,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of Decoder layers.", help="Number of Decoder layers.",
) )
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
def get_parser():
    parser = argparse.ArgumentParser(
@ -313,13 +320,6 @@ def get_parser():
help="perform OOM check on dataloader batches before starting training.", help="perform OOM check on dataloader batches before starting training.",
) )
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
    add_model_arguments(parser)
    return parser