update results

root 2025-02-20 08:48:11 +00:00
parent a54a0469a2
commit 2edaf685e1
3 changed files with 222 additions and 72 deletions

README.md

@@ -1,3 +1,10 @@
# Results
| Model | Seed-TTS test_zh CER | Comment |
|---------------------------------------|---------------------|--------|
| [vall-e](./valle) | 4.33% | ~150M |
| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M |
| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M |
# Introduction
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
@@ -131,6 +138,53 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
bash local/compute_wer.sh $output_dir $manifest
```
# F5-TTS-Semantic-Token
./f5-tts contains the code for training F5-TTS-Semantic-Token. We replace the text tokens in F5-TTS with pretrained CosyVoice2 semantic tokens, which we observed to give faster convergence and better prosody modeling.
Generated samples and training logs for the ~7,000-hour WenetSpeech4TTS Basic subset can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).
Preparation:
```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 5 --stop_stage 7
```
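Under the hood, these stages run the CosyVoice2 speech tokenizer (25 tokens/sec) over each utterance. Below is a minimal sketch of the core extraction call, reusing the checkout and model paths that the inference code assumes; treat the paths and the input file name as placeholders for your setup:
```
import sys

import torchaudio

# assumed paths -- point these at your own CosyVoice checkout and model
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice")
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2(
    "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)

wav, sr = torchaudio.load("utt.wav")  # hypothetical mono utterance
if sr != 16000:
    # the speech tokenizer expects 16 kHz input
    wav = torchaudio.transforms.Resample(sr, 16000)(wav)

# returns the semantic token ids (25 tokens/sec) and their length
speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(wav)
```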
The training command is given below:
```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
world_size=8
exp_dir=exp/f5-tts-semantic-token-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
--num-epochs 10 --start-epoch 1 --start-batch 0 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size} \
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
```
To run inference with the F5-TTS-Semantic-Token (Small) model trained on WenetSpeech4TTS with icefall, use:
```
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
split=test_zh
model_path=${exp_dir}/epoch-10-avg-5.pt
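# set output_dir and manifest to your own paths before running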
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
```
# Credits
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)

f5-tts/infer.py

@@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
# command for text token input
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
# command for cosyvoice semantic token input
split=test_zh # seed_tts_eval test_zh
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
@@ -37,11 +44,12 @@ from train import (
add_model_arguments,
get_model,
get_tokenizer,
- insert_zeros_optimized,
+ interpolate_tokens,
load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
from icefall.utils import str2bool
def get_parser():
@@ -94,9 +102,17 @@
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument(
"--insert-zero",
action="store_true",
help="Insert zeros for CosyVoice",
"--interpolate-token",
type=str2bool,
default=True,
help="Interpolate semantic token to match mel frames for CosyVoice",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
parser.add_argument(
@@ -277,7 +293,7 @@ def get_inference_prompt_cosy_voice_huggingface(
num_buckets=200,
min_secs=3,
max_secs=40,
- insert_zero=False,
+ interpolate_token=False,
):
prompts_all = []
@@ -319,15 +335,15 @@
ref_audio = ref_audio_org
input_tokens = prompt_audio_tokens + audio_tokens
- if insert_zero:
- input_tokens = insert_zeros_optimized(input_tokens)
+ if interpolate_token:
+ input_tokens = interpolate_tokens(input_tokens)
text_list = input_tokens
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
total_mel_len = len(input_tokens)
- if not insert_zero:
+ if not interpolate_token:
total_mel_len = int(total_mel_len / 4 * 15)
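# (4 cosyvoice tokens at 25/sec span 15 mel frames at 24000/256 = 93.75/sec, hence the 15/4 scaling above)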
# to mel spectrogram
@@ -406,6 +422,51 @@ get_inference_prompt_cosy_voice_huggingface(
return prompts_all
def inference_speech_token(
cosyvoice,
tts_text,
prompt_text,
prompt_speech_16k,
stream=False,
speed=1.0,
text_frontend=True,
):
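"""Generate speech tokens for tts_text with the CosyVoice2 LLM, conditioned on prompt_text and the 16 kHz prompt speech; returns (generated speech tokens, prompt speech tokens)."""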
tokens = []
prompt_text = cosyvoice.frontend.text_normalize(
prompt_text, split=False, text_frontend=text_frontend
)
for i in cosyvoice.frontend.text_normalize(
tts_text, split=True, text_frontend=text_frontend
):
tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
(
prompt_text_token,
prompt_text_token_len,
) = cosyvoice.frontend._extract_text_token(prompt_text)
speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
prompt_speech_16k
)
for i in cosyvoice.model.llm.inference(
text=tts_text_token.to(cosyvoice.model.device),
text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
cosyvoice.model.device
),
prompt_text=prompt_text_token.to(cosyvoice.model.device),
prompt_text_len=torch.tensor(
[prompt_text_token.shape[1]], dtype=torch.int32
).to(cosyvoice.model.device),
prompt_speech_token=speech_token.to(cosyvoice.model.device),
prompt_speech_token_len=torch.tensor(
[speech_token.shape[1]], dtype=torch.int32
).to(cosyvoice.model.device),
embedding=None,
):
tokens.append(i)
return tokens, speech_token
def get_inference_prompt_cosy_voice(
metainfo,
speed=1.0,
@@ -423,18 +484,21 @@
num_buckets=200,
min_secs=3,
max_secs=40,
- insert_zero=False,
+ interpolate_token=False,
):
import sys
# please change these paths to point at your CosyVoice checkout
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice")
from cosyvoice.cli.cosyvoice import CosyVoice2
# please download the cosyvoice model first
cosyvoice = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
@@ -466,8 +530,8 @@
ref_audio_16k = resampler(ref_audio_org)
else:
ref_audio_16k = ref_audio_org
- audio_tokens, prompt_audio_tokens = cosyvoice.inference_speech_token(
- gt_text, prompt_text, ref_audio_16k, stream=False
+ audio_tokens, prompt_audio_tokens = inference_speech_token(
+ cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
@@ -499,8 +563,8 @@
# convert it into a list
# input_tokens_list = input_tokens.squeeze().cpu().tolist()
- if insert_zero:
- input_tokens = insert_zeros_optimized(input_tokens)
+ if interpolate_token:
+ input_tokens = interpolate_tokens(input_tokens)
text_list = input_tokens
# Duration, mel frame length
@@ -521,7 +585,7 @@
ref_mel_len / ref_text_len * gen_text_len / speed
)
total_mel_len = len(input_tokens)
- if not insert_zero:
+ if not interpolate_token:
total_mel_len = int(total_mel_len / 4 * 15)
print(
f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
@@ -632,19 +696,20 @@ def main():
device = f"cuda:{accelerator.process_index}"
if args.manifest_file:
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
- # prompts_all = get_inference_prompt(
- #     metainfo,
- #     speed=1.0,
- #     tokenizer="pinyin",
- #     target_sample_rate=24_000,
- #     n_mel_channels=100,
- #     hop_length=256,
- #     mel_spec_type="bigvgan",
- #     target_rms=0.1,
- #     use_truth_duration=False,
- #     infer_batch_size=1,
- # )
+ if not args.use_cosyvoice_semantic_token:
+ prompts_all = get_inference_prompt(
+ metainfo,
+ speed=1.0,
+ tokenizer="pinyin",
+ target_sample_rate=24_000,
+ n_mel_channels=100,
+ hop_length=256,
+ mel_spec_type="bigvgan",
+ target_rms=0.1,
+ use_truth_duration=False,
+ infer_batch_size=1,
+ )
+ else:
prompts_all = get_inference_prompt_cosy_voice(
metainfo,
speed=1.0,
@@ -656,9 +721,10 @@
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
- insert_zero=args.insert_zero,
+ interpolate_token=args.interpolate_token,
)
else:
assert args.use_cosyvoice_semantic_token
dataset = datasets.load_dataset(
"yuekai/seed_tts_cosy2",
split=args.split_name,
@@ -675,7 +741,7 @@
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
- insert_zero=args.insert_zero,
+ interpolate_token=args.interpolate_token,
)
vocoder = BigVGANInference.from_pretrained(
@@ -712,6 +778,7 @@
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
if args.use_cosyvoice_semantic_token:
# concat final_text_list
max_len = max([len(tokens) for tokens in final_text_list])
# pad tokens to the same length

f5-tts/train.py

@@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size}
# command for training with cosyvoice semantic token
exp_dir=exp/f5-tts-cosyvoice
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
--num-epochs 10 --start-epoch 1 --start-batch 0 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size} \
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
"""
import argparse
@@ -303,6 +313,13 @@ def get_parser():
help="perform OOM check on dataloader batches before starting training.",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
add_model_arguments(parser)
return parser
@@ -378,9 +395,11 @@ def get_tokenizer(vocab_file_path: str):
def get_model(params):
- vocab_char_map, vocab_size = get_tokenizer(params.tokens)
- vocab_char_map, vocab_size = None, 6561
+ if params.use_cosyvoice_semantic_token:
+ # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
+ vocab_char_map, vocab_size = None, 6561
+ else:
+ vocab_char_map, vocab_size = get_tokenizer(params.tokens)
# bigvgan 100 dim features
n_mel_channels = 100
n_fft = 1024
@@ -558,37 +577,43 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
- def insert_zeros_optimized(arr):
+ def interpolate_tokens(cosy_tokens, pad_token=-1):
"""Interpolate cosyvoice tokens to match bigvgan frames length"""
# cosyvoice, 25 tokens/sec
# bigvgan sample_rate/hop_length 24000/256 frames/sec
# For every 4 cosyvoice tokens, insert pad tokens to extend them to 15 entries, matching the bigvgan frame length
# We choose 4,4,4,3 to match 15 frames
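# (24000 / 256) / 25 = 3.75 frames per token, so a group of 4 tokens spans exactly 15 frames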
- three, two = [-1] * 3, [-1] * 2
+ three, two = [pad_token] * 3, [pad_token] * 2
return [
- x for i, e in enumerate(arr) for x in ([e] + three if i % 4 < 3 else [e] + two)
+ x
+ for i, e in enumerate(cosy_tokens)
+ for x in ([e] + three if i % 4 < 3 else [e] + two)
]
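# e.g. interpolate_tokens([7, 8, 9, 10]) -> [7, -1, -1, -1, 8, -1, -1, -1, 9, -1, -1, -1, 10, -1, -1]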
- def prepare_input(batch: dict, device: torch.device):
+ def prepare_input(
+ batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool
+ ):
"""Parse batch data"""
- # text_inputs = batch["text"]
- # text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
mel_spec = batch["features"]
mel_lengths = batch["features_lens"]
if use_cosyvoice_semantic_token:
semantic_tokens = []
for i in range(len(batch["tokens"])):
tokens = batch["tokens"][i]
- # tokens = insert_zeros_optimized(tokens)
+ tokens = interpolate_tokens(tokens)
semantic_tokens.append(tokens)
# pad to the same length, B,T, with pad value -1
max_len = max([len(tokens) for tokens in semantic_tokens])
- text_inputs = torch.full((len(semantic_tokens), max_len), -1, dtype=torch.long).to(
- device
- )
+ text_inputs = torch.full(
+ (len(semantic_tokens), max_len), -1, dtype=torch.long
+ ).to(device)
for i, tokens in enumerate(semantic_tokens):
text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
else:
text_inputs = batch["text"]
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
@@ -619,7 +644,11 @@ def compute_loss(
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
- (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
+ (text_inputs, mel_spec, mel_lengths) = prepare_input(
+ batch,
+ device=device,
+ use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token,
+ )
# at entry, TextTokens is (N, P)
with torch.set_grad_enabled(is_training):