diff --git a/egs/emilia/TTS/README.md b/egs/emilia/TTS/README.md new file mode 100644 index 000000000..363ea3842 --- /dev/null +++ b/egs/emilia/TTS/README.md @@ -0,0 +1,94 @@ +# Results +| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment | +|---------------------------------------|----------|-----------|--------| +| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)| +| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)| +| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.89% (16 steps) | | + +# Introduction + +[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) starts with over 101k +hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation. + +See https://arxiv.org/pdf/2407.05361. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + + + +# Llasa (cosyvoice2 token) + +./llasa_cosyvoice2_token contains the code for training qwen2.5-0.5b models to predict cosyvoice2 semantic tokens. + +Generated samples and training logs of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) 50k hours Chinese data can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main). + +Preparation: + +``` +# extract cosyvoice2 semantic tokens +bash prepare.sh --stage 3 --stop_stage 4 + +# Or you could use the prepared tokens. +huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token +``` + +The training command is given below: + +``` +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install -r llasa_cosyvoice2_token/requirements.txt +# pip install -r icefall/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt + +WANDB_KEY=$your_wandb_key +wandb login ${WANDB_KEY} +huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct +torchrun --nproc_per_node=8 train.py config.json +``` + +To inference with Icefall Emilia trained Chinese Llasa_cosyvoice2_token model, we need to use cosyvoice2 token flow matching [model](https://github.com/k2-fsa/icefall/pull/1880): +``` +cd icefall/egs/wenetspeech4tts/TTS +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +vocoder=./bigvgan_v2_24khz_100band_256x +split=test_zh +llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000 + +huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt +torchrun --nproc_per_node=2 \ + f5-tts/infer_dist.py \ + --output_dir $output_dir \ + --batch_size 1 \ + --num_workers 2 \ + --llm-model-name-or-path $llm_path \ + --flow-matching-model-path $model_path \ + --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --use-cosyvoice-semantic-token True \ + --vocoder-dir $vocoder \ + --split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \ + --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct +# compute cer +huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +bash local/compute_wer.sh $output_dir $manifest +``` + +# Credits +- [Llasa](https://arxiv.org/abs/2502.04128) +- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) +- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main) diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json index 06aeb51f1..858edae84 100644 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json @@ -1,6 +1,6 @@ { - "llm_model_name_or_path": "/workspace/slam/icefall_omni/egs/speech_llm/SPEECH2SPEECH/models/Qwen2.5-0.5B-Instruct", - "data_path": ["../emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"], + "llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct", + "data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"], "bf16": false, "output_dir": "./exp_zh", "num_train_epochs": 3, diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt index 09e069d3a..11574c190 100644 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt @@ -5,3 +5,4 @@ datasets accelerate>=0.26.0 deepspeed flash-attn +s3tokenizer diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py index 159e483d7..e3c6fcae6 100644 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py @@ -1,3 +1,11 @@ +# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py +""" Example Usage +WANDB_KEY=$your_wandb_key +wandb login ${WANDB_KEY} +huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token +huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct +torchrun --nproc_per_node=8 train.py config.json +""" import json import os import random @@ -11,8 +19,7 @@ import torch import torch.nn as nn import transformers import wandb -from datasets import load_dataset, load_from_disk -from torch.utils.data import DataLoader, Dataset +from datasets import load_dataset from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -65,7 +72,7 @@ class CustomTrainingArguments(TrainingArguments): remove_unused_columns: bool = field(default=False) -def data_collator(batch, tokenizer): +def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048): speech_generation_start_index = tokenizer.convert_tokens_to_ids( "<|SPEECH_GENERATION_START|>" ) @@ -84,11 +91,11 @@ def data_collator(batch, tokenizer): chat_template=TEMPLATE, ) - code = [c + 151665 for c in code] + code = [c + original_tokenizer_vocab_size for c in code] idx = input_ids.index(speech_generation_start_index) input_ids = input_ids[:idx] + code + input_ids[idx + 1 :] - if len(input_ids) < 2048: + if len(input_ids) < cut_off_len: input_ids_list.append(input_ids) max_len = max([len(input_ids) for input_ids in input_ids_list]) @@ -140,7 +147,11 @@ def main(): ) tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path) - new_tokens = [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"] + original_tokenizer_vocab_size = len(tokenizer) + cosyvoice2_token_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [ + "<|SPEECH_GENERATION_START|>" + ] num_added_tokens = tokenizer.add_tokens(new_tokens) model.resize_token_embeddings(len(tokenizer)) @@ -157,7 +168,9 @@ def main(): args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, - data_collator=lambda features: data_collator(features, tokenizer), + data_collator=lambda features: data_collator( + features, tokenizer, original_tokenizer_vocab_size + ), ) if is_main_process: diff --git a/egs/emilia/TTS/local/extract_cosyvoice2_token.py b/egs/emilia/TTS/local/extract_cosyvoice2_token.py index 2c1ccda76..2a6d1d380 100644 --- a/egs/emilia/TTS/local/extract_cosyvoice2_token.py +++ b/egs/emilia/TTS/local/extract_cosyvoice2_token.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# 2025 (authors: Yuekai Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Example Usage -cpu: - -s3tokenizer --data_dir xxx.scp \ - --device "cpu" \ - --output_dir "./" \ - --batch_size 32 - -gpu: - torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - `which s3tokenizer` --data_dir xxx.scp \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + local/extract_cosyvoice2_token.py --data_dir $data_dir \ + --jsonl_file $jsonl_file_basename \ --device "cuda" \ - --output_dir "./" \ - --batch_size 32 + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 2 \ + --model "speech_tokenizer_v2_25hz" """ diff --git a/egs/emilia/TTS/prepare.sh b/egs/emilia/TTS/prepare.sh index 4a0d2df0b..8abcfaf61 100755 --- a/egs/emilia/TTS/prepare.sh +++ b/egs/emilia/TTS/prepare.sh @@ -4,16 +4,17 @@ set -eou pipefail # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python -# pip install lhotse s3tokenizer -stage=6 -stop_stage=6 +stage=3 +stop_stage=4 + +# Please download the OpenDataLab format from HuggingFace, you can specify the revision argument to fc71e07e8572f5f3be1dbd02ed3172a4d298f152, which is the old format. +# https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07e8572f5f3be1dbd02ed3172a4d298f152 dl_dir=$PWD/download -dl_dir=/workspace_data/Emilia-Dataset/ + prefix="emilia" # zh, en, ja, ko, de, fr lang_set=("de" "en" "zh" "ja" "ko" "fr") -lang_set=("de" "en" "zh" "ja" "fr") . shared/parse_options.sh || exit 1 @@ -29,23 +30,20 @@ log() { if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "dl_dir: $dl_dir" log "Stage 0: Download data" - #huggingface-cli login - # huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS - # Extract the downloaded data: + cat $dl_dir/raw/EN/EN_B00008.tar.gz.* > $dl_dir/raw/EN/EN_B00008.tar.gz for lang in "${lang_set[@]}"; do lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') folder=$dl_dir/raw/${lang_upper} for file in $folder/*.tar.gz; do echo "Processing ${file}" - # e.g. $dl_dir/raw/DE/*tar.gz untar first, DE is the language code in upper case tar -xzvf $file -C $folder done done fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare emilia manifest" + log "Stage 1: Prepare emilia manifest (used by ./f5-tts)" # We assume that you have downloaded the Emilia corpus # to $dl_dir/emilia mkdir -p data/manifests @@ -58,7 +56,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then done fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Generate fbank (used by ./f5-tts)" mkdir -p data/fbank @@ -71,67 +68,8 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then done fi -if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then - log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)" - if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then - echo "Combining ${prefix} cuts" - pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz") - lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz - fi - if [ ! -e data/fbank/.${prefix}_split.done ]; then - echo "Splitting ${prefix} cuts into train, valid and test sets" - - lhotse subset --last 800 \ - data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz - lhotse subset --first 400 \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz \ - data/fbank/${prefix}_cuts_valid.jsonl.gz - lhotse subset --last 400 \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz \ - data/fbank/${prefix}_cuts_test.jsonl.gz - - rm data/fbank/${prefix}_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 )) - lhotse subset --first $n \ - data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ - data/fbank/${prefix}_cuts_train.jsonl.gz - touch data/fbank/.${prefix}_split.done - fi -fi - -# zcat test.jsonl.gz | jq -r '.recording.id + " " + .recording.sources[0].source' > wav.scp -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" - data_dir=$dl_dir/raw/ZH - # for all jsonl files in data_dir - for jsonl_file in $data_dir/*.jsonl; do - # get the file basename - jsonl_file_basename=$(basename $jsonl_file) - echo "Processing $jsonl_file" - output_dir="./cosy_v2_tokens_ZH/${jsonl_file_basename%.jsonl}" - echo "output_dir: $output_dir" - # skip if the output_dir exists - if [ -e $output_dir ]; then - echo "Output directory $output_dir already exists, skipping" - continue - fi - mkdir -p $output_dir - torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - local/extract_cosyvoice2_token.py --data_dir $data_dir \ - --jsonl_file $jsonl_file_basename \ - --device "cuda" \ - --output_dir $output_dir \ - --batch_size 32 \ - --num_workers 2 \ - --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz - done -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Extract cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)" for lang in "${lang_set[@]}"; do lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') data_dir=$dl_dir/raw/${lang_upper} @@ -161,14 +99,13 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then done fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then -# cat EN_B00008.tar.gz.* > EN_B00008.tar.gz +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Merge cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)" for lang in "${lang_set[@]}"; do lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') cosy_token_dir="./cosy_v2_tokens_${lang_upper}" for dir in $cosy_token_dir/*; do echo "Processing $dir" - # get the file basename dir_basename=$(basename $dir) echo "dir_basename: $dir_basename" cat $dir/part* > $dir/${dir_basename}.jsonl diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md index 8329ae948..9a48bd196 100644 --- a/egs/wenetspeech4tts/TTS/README.md +++ b/egs/wenetspeech4tts/TTS/README.md @@ -186,3 +186,5 @@ bash local/compute_wer.sh $output_dir $manifest - [VALL-E](https://github.com/lifeiteng/vall-e) - [F5-TTS](https://github.com/SWivid/F5-TTS) - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) +- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main) +- [Spark-TTS](https://github.com/SparkAudio/Spark-TTS) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 6964a43be..b90657d0e 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -108,13 +108,6 @@ def get_parser(): help="Interpolate semantic token to match mel frames for CosyVoice", ) - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - parser.add_argument( "--split-name", type=str, diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py index 59e222a74..636720f03 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# 2025 (authors: Yuekai Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,23 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py """ Example Usage -cpu: - -s3tokenizer --data_dir xxx.scp \ - --device "cpu" \ - --output_dir "./" \ - --batch_size 32 - -gpu: - -torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - `which s3tokenizer` --data_dir xxx.scp \ - --device "cuda" \ - --output_dir "./" \ - --batch_size 32 - +split=test_zh +llm_path=f5-tts/exp_zh/checkpoint-805000 +huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x +vocoder=./bigvgan_v2_24khz_100band_256x +torchrun --nproc_per_node=2 \ + f5-tts/infer_dist.py \ + --output_dir $output_dir \ + --batch_size 1 \ + --num_workers 2 \ + --llm-model-name-or-path $llm_path \ + --flow-matching-model-path $model_path \ + --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --use-cosyvoice-semantic-token True \ + --vocoder-dir $vocoder \ + --split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \ + --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct """ import argparse @@ -81,16 +85,16 @@ def get_args(): help="huggingface dataset split name", ) parser.add_argument( - "--output_dir", required=True, type=str, help="dir to save result" + "--output-dir", required=True, type=str, help="dir to save result" ) parser.add_argument( - "--batch_size", + "--batch-size", required=True, type=int, help="batch size (per-device) for inference", ) parser.add_argument( - "--num_workers", type=int, default=4, help="workers for dataloader" + "--num-workers", type=int, default=4, help="workers for dataloader" ) parser.add_argument( "--prefetch", type=int, default=5, help="prefetch for dataloader" @@ -119,6 +123,24 @@ def get_args(): type=str, help="flow matching model path", ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) add_model_arguments(parser) args = parser.parse_args() return args @@ -285,7 +307,11 @@ def main(): for batch in dataloader: generate_codes = model.inference_batch( - batch["input_ids"], batch["attention_mask"] + batch["input_ids"], + batch["attention_mask"], + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, ) flow_matching_input_tokens, total_mel_lens = [], [] for i, code in enumerate(generate_codes): diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py deleted file mode 100644 index 6964a43be..000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py +++ /dev/null @@ -1,828 +0,0 @@ -#!/usr/bin/env python3 -# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py -""" -Usage: -# docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html -# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx -# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x -manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst -python3 f5-tts/generate_averaged_model.py \ - --epoch 56 \ - --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ - --exp-dir exp/f5_small - -# command for text token input -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 - -# command for cosyvoice semantic token input -split=test_zh # seed_tts_eval test_zh -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True - -bash local/compute_wer.sh $output_dir $manifest -""" -import argparse -import logging -import math -import os -import random -import time -from pathlib import Path - -import datasets -import torch -import torch.nn.functional as F -import torchaudio -from accelerate import Accelerator -from bigvganinference import BigVGANInference -from model.cfm import CFM -from model.dit import DiT -from model.modules import MelSpec -from model.utils import convert_char_to_pinyin -from tqdm import tqdm -from train import ( - add_model_arguments, - get_model, - get_tokenizer, - interpolate_tokens, - load_F5_TTS_pretrained_checkpoint, -) - -from icefall.checkpoint import load_checkpoint -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tokens", - type=str, - default="f5-tts/vocab.txt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--model-path", - type=str, - default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--seed", - type=int, - default=0, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--nfe", - type=int, - default=16, - help="The number of steps for the neural ODE", - ) - - parser.add_argument( - "--manifest-file", - type=str, - default=None, - help="The manifest file in seed_tts_eval format", - ) - - parser.add_argument( - "--output-dir", - type=Path, - default="results", - help="The output directory to save the generated wavs", - ) - - parser.add_argument("-ss", "--swaysampling", default=-1, type=float) - - parser.add_argument( - "--interpolate-token", - type=str2bool, - default=True, - help="Interpolate semantic token to match mel frames for CosyVoice", - ) - - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - - parser.add_argument( - "--split-name", - type=str, - default="wenetspeech4tts", - choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], - help="huggingface dataset split name", - ) - - add_model_arguments(parser) - return parser.parse_args() - - -def get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, -): - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( - metainfo, desc="Processing prompts..." - ): - # Audio - ref_audio, ref_sr = torchaudio.load(prompt_wav) - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) - if ref_rms < target_rms: - ref_audio = ref_audio * target_rms / ref_rms - assert ( - ref_audio.shape[-1] > 5000 - ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio) - - # Text - if len(prompt_text[-1].encode("utf-8")) == 1: - prompt_text = prompt_text + " " - text = [prompt_text + gt_text] - if tokenizer == "pinyin": - text_list = convert_char_to_pinyin(text, polyphone=polyphone) - else: - text_list = text - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - if use_truth_duration: - gt_audio, gt_sr = torchaudio.load(gt_wav) - if gt_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) - gt_audio = resampler(gt_audio) - total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) - - # # test vocoder resynthesis - # ref_audio = gt_audio - else: - ref_text_len = len(prompt_text.encode("utf-8")) - gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len = ref_mel_len + int( - ref_mel_len / ref_text_len * gen_text_len / speed - ) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - final_text_list[bucket_i].extend(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def get_inference_prompt_cosy_voice_huggingface( - dataset, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, - interpolate_token=False, -): - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for i in range(len(dataset)): - utt = dataset[i]["id"] - ref_audio_org, ref_sr = ( - dataset[i]["prompt_audio"]["array"], - dataset[i]["prompt_audio"]["sampling_rate"], - ) - ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() - audio_tokens = dataset[i]["target_audio_cosy2_tokens"] - prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] - - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) - if ref_rms < target_rms: - ref_audio_org = ref_audio_org * target_rms / ref_rms - - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio_org) - else: - ref_audio = ref_audio_org - input_tokens = prompt_audio_tokens + audio_tokens - - if interpolate_token: - input_tokens = interpolate_tokens(input_tokens) - text_list = input_tokens - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - - total_mel_len = len(input_tokens) - if not interpolate_token: - total_mel_len = int(total_mel_len / 4 * 15) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - if total_mel_len > max_tokens: - print( - f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - ) - continue - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - # final_text_list[bucket_i].extend(text_list) - final_text_list[bucket_i].append(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def inference_speech_token( - cosyvoice, - tts_text, - prompt_text, - prompt_speech_16k, - stream=False, - speed=1.0, - text_frontend=True, -): - tokens = [] - prompt_text = cosyvoice.frontend.text_normalize( - prompt_text, split=False, text_frontend=text_frontend - ) - for i in cosyvoice.frontend.text_normalize( - tts_text, split=True, text_frontend=text_frontend - ): - - tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) - ( - prompt_text_token, - prompt_text_token_len, - ) = cosyvoice.frontend._extract_text_token(prompt_text) - speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( - prompt_speech_16k - ) - - for i in cosyvoice.model.llm.inference( - text=tts_text_token.to(cosyvoice.model.device), - text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( - cosyvoice.model.device - ), - prompt_text=prompt_text_token.to(cosyvoice.model.device), - prompt_text_len=torch.tensor( - [prompt_text_token.shape[1]], dtype=torch.int32 - ).to(cosyvoice.model.device), - prompt_speech_token=speech_token.to(cosyvoice.model.device), - prompt_speech_token_len=torch.tensor( - [speech_token.shape[1]], dtype=torch.int32 - ).to(cosyvoice.model.device), - embedding=None, - ): - tokens.append(i) - return tokens, speech_token - - -def get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, - interpolate_token=False, -): - - import sys - - # please change the path to the cosyvoice accordingly - sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") - sys.path.append("/workspace/CosyVoice") - from cosyvoice.cli.cosyvoice import CosyVoice2 - - # please download the cosyvoice model first - cosyvoice = CosyVoice2( - "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False - ) - - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( - metainfo, desc="Processing prompts..." - ): - # Audio - ref_audio_org, ref_sr = torchaudio.load(prompt_wav) - - # cosy voice - if ref_sr != 16000: - resampler = torchaudio.transforms.Resample(ref_sr, 16000) - ref_audio_16k = resampler(ref_audio_org) - else: - ref_audio_16k = ref_audio_org - audio_tokens, prompt_audio_tokens = inference_speech_token( - cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False - ) - - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) - if ref_rms < target_rms: - ref_audio_org = ref_audio_org * target_rms / ref_rms - assert ( - ref_audio_org.shape[-1] > 5000 - ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio_org) - else: - ref_audio = ref_audio_org - - # Text - # if len(prompt_text[-1].encode("utf-8")) == 1: - # prompt_text = prompt_text + " " - # text = [prompt_text + gt_text] - # if tokenizer == "pinyin": - # text_list = convert_char_to_pinyin(text, polyphone=polyphone) - # else: - # text_list = text - - # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens - # prompt_audio_tokens shape 1, prompt_audio_tokens - # audio_tokens shape 1, audio_tokens - prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() - input_tokens = prompt_audio_tokens + audio_tokens - - # convert it into a list - # input_tokens_list = input_tokens.squeeze().cpu().tolist() - if interpolate_token: - input_tokens = interpolate_tokens(input_tokens) - text_list = input_tokens - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - if use_truth_duration: - gt_audio, gt_sr = torchaudio.load(gt_wav) - if gt_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) - gt_audio = resampler(gt_audio) - total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) - - # # test vocoder resynthesis - # ref_audio = gt_audio - else: - ref_text_len = len(prompt_text.encode("utf-8")) - gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len_compute = ref_mel_len + int( - ref_mel_len / ref_text_len * gen_text_len / speed - ) - total_mel_len = len(input_tokens) - if not interpolate_token: - total_mel_len = int(total_mel_len / 4 * 15) - print( - f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" - ) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - # final_text_list[bucket_i].extend(text_list) - final_text_list[bucket_i].append(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def padded_mel_batch(ref_mels): - max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() - padded_ref_mels = [] - for mel in ref_mels: - padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) - padded_ref_mels.append(padded_ref_mel) - padded_ref_mels = torch.stack(padded_ref_mels) - padded_ref_mels = padded_ref_mels.permute(0, 2, 1) - return padded_ref_mels - - -def get_seedtts_testset_metainfo(metalst): - f = open(metalst) - lines = f.readlines() - f.close() - metainfo = [] - for line in lines: - assert len(line.strip().split("|")) == 4 - utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") - utt = Path(utt).stem - gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") - if not os.path.isabs(prompt_wav): - prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) - metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) - return metainfo - - -def main(): - args = get_parser() - - accelerator = Accelerator() - device = f"cuda:{accelerator.process_index}" - if args.manifest_file: - metainfo = get_seedtts_testset_metainfo(args.manifest_file) - if not args.use_cosyvoice_semantic_token: - prompts_all = get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - ) - else: - prompts_all = get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - interpolate_token=args.interpolate_token, - ) - else: - assert args.use_cosyvoice_semantic_token - dataset = datasets.load_dataset( - "yuekai/seed_tts_cosy2", - split=args.split_name, - trust_remote_code=True, - ) - prompts_all = get_inference_prompt_cosy_voice_huggingface( - dataset, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - interpolate_token=args.interpolate_token, - ) - - vocoder = BigVGANInference.from_pretrained( - "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False - ) - vocoder = vocoder.eval().to(device) - - model = get_model(args).eval().to(device) - checkpoint = torch.load(args.model_path, map_location="cpu") - if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: - model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) - else: - _ = load_checkpoint( - args.model_path, - model=model, - ) - - os.makedirs(args.output_dir, exist_ok=True) - - accelerator.wait_for_everyone() - start = time.time() - - with accelerator.split_between_processes(prompts_all) as prompts: - for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): - ( - utts, - ref_rms_list, - ref_mels, - ref_mel_lens, - total_mel_lens, - final_text_list, - ) = prompt - ref_mels = ref_mels.to(device) - ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) - total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) - - if args.use_cosyvoice_semantic_token: - # concat final_text_list - max_len = max([len(tokens) for tokens in final_text_list]) - # pad tokens to the same length - for i, tokens in enumerate(final_text_list): - final_text_list[i] = torch.tensor( - tokens + [-1] * (max_len - len(tokens)), dtype=torch.long - ) - final_text_list = torch.stack(final_text_list).to(device) - - # Inference - with torch.inference_mode(): - generated, _ = model.sample( - cond=ref_mels, - text=final_text_list, - duration=total_mel_lens, - lens=ref_mel_lens, - steps=args.nfe, - cfg_strength=2.0, - sway_sampling_coef=args.swaysampling, - no_ref_audio=False, - seed=args.seed, - ) - for i, gen in enumerate(generated): - gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) - gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) - - generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() - target_rms = 0.1 - target_sample_rate = 24_000 - if ref_rms_list[i] < target_rms: - generated_wave = generated_wave * ref_rms_list[i] / target_rms - torchaudio.save( - f"{args.output_dir}/{utts[i]}.wav", - generated_wave, - target_sample_rate, - ) - - accelerator.wait_for_everyone() - if accelerator.is_main_process: - timediff = time.time() - start - print(f"Done batch inference in {timediff / 60 :.2f} minutes.") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py index bf878db51..1d0fdc5c8 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py @@ -1,5 +1,6 @@ # Copyright (c) 2025 SparkAudio # 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# 2025 Yuekai Zhang # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py +# Modified from https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py import re from pathlib import Path @@ -39,7 +40,9 @@ class LLMTTS: Args: model_dir (Path): Directory containing the model and config files. - device (torch.device): The device (CPU/GPU) to run the model on. + tokenizer_dir (Path): Directory containing the tokenizer files. + s3_tokenizer_name (str): Name of the tokenizer file on S3. + device (torch.device): Device to run the model on. """ self.device = device @@ -51,7 +54,9 @@ class LLMTTS: ) tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) - new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [ + self.original_vocab_size = len(tokenizer) + self.cosyvoice2_token_vocab_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(self.cosyvoice2_token_vocab_size)] + [ "<|SPEECH_GENERATION_START|>" ] num_added_tokens = tokenizer.add_tokens(new_tokens) @@ -67,42 +72,39 @@ class LLMTTS: temperature: float = 0.8, top_k: float = 50, top_p: float = 0.95, + max_new_tokens: int = 1024, ) -> torch.Tensor: """ Performs inference to generate speech from text, incorporating prompt audio and/or text. Args: - text (str): The text input to be converted to speech. - prompt_speech_path (Path): Path to the audio file used as a prompt. - prompt_text (str, optional): Transcript of the prompt audio. - gender (str): female | male. - pitch (str): very_low | low | moderate | high | very_high - speed (str): very_low | low | moderate | high | very_high + input_ids (torch.Tensor): Input IDs for the model. + attention_mask (torch.Tensor): Attention mask for the model. temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. top_k (float, optional): Top-k sampling parameter. Default is 50. top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. + max_new_tokens (int, optional): Maximum number of tokens to generate. Default is 1024. Returns: torch.Tensor: Generated waveform as a tensor. """ - # Generate speech using the model generated_ids = self.model.generate( input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), - max_new_tokens=1024, + max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, ) - results = [] generated_ids = generated_ids.cpu().tolist() for i in range(len(generated_ids)): assistant_index = generated_ids[i].index(self.assistant_index) padding_index = len(generated_ids[i]) + # WAR: harding coding assistant_index + 2, for the current template Assistant: \n result = generated_ids[i][assistant_index + 2 :] - result = [token - 151665 for token in result] + result = [token - self.original_vocab_size for token in result] result = [token for token in result if token >= 0] results.append(result) return results diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 5333b3f27..343d0c65c 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -118,6 +118,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Number of Decoder layers.", ) + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -313,13 +320,6 @@ def get_parser(): help="perform OOM check on dataloader batches before starting training.", ) - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - add_model_arguments(parser) return parser