update results

root 2025-02-20 08:48:11 +00:00
parent a54a0469a2
commit 2edaf685e1
3 changed files with 222 additions and 72 deletions

README.md

@@ -1,3 +1,10 @@
# Results
| Model | Seed-TTS test_zh CER | Comment |
|-------|----------------------|---------|
| [vall-e](./valle) | 4.33% | ~150M parameters |
| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small config, ~155M parameters |
| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Uses pretrained CosyVoice2 semantic tokens as inputs rather than text tokens, ~155M parameters |
# Introduction
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
@@ -131,6 +138,53 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
bash local/compute_wer.sh $output_dir $manifest
```
# F5-TTS-Semantic-Token
The `./f5-tts` directory contains the code for training F5-TTS-Semantic-Token, which replaces the text tokens in F5-TTS with pretrained CosyVoice2 semantic tokens.
We observed faster convergence and better prosody modeling with this change.
Generated samples and training logs for the 7,000-hour WenetSpeech4TTS Basic subset are available [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).
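Since CosyVoice2 emits 25 semantic tokens per second while BigVGAN mel frames run at 24000 / 256 = 93.75 frames per second, each group of 4 tokens has to cover 15 mel frames. The `interpolate_tokens` helper in `f5-tts/train.py` handles this by padding tokens with a pad value of -1; a minimal sketch of that logic:
```
# sketch of the 25 Hz -> 93.75 Hz alignment done by interpolate_tokens
def interpolate_tokens(cosy_tokens, pad_token=-1):
    # each group of 4 tokens expands to 4 + 4 + 4 + 3 = 15 frames
    three, two = [pad_token] * 3, [pad_token] * 2
    return [
        x
        for i, e in enumerate(cosy_tokens)
        for x in ([e] + three if i % 4 < 3 else [e] + two)
    ]

assert len(interpolate_tokens([10, 11, 12, 13])) == 15  # 4 tokens -> 15 frames
```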
Preparation:
```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 5 --stop_stage 7
```
The training command is given below:
```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
world_size=8
exp_dir=exp/f5-tts-semantic-token-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
--num-epochs 10 --start-epoch 1 --start-batch 0 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size} \
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
```
To run inference with the F5-Small-Semantic-Token model trained on WenetSpeech4TTS with icefall, use:
```
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
split=test_zh
model_path=${exp_dir}/epoch-10-avg-5.pt
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
```
# Credits
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)

f5-tts/infer.py

@@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
    --epoch 56 \
    --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
    --exp-dir exp/f5_small
# command for text token input
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
# command for cosyvoice semantic token input
split=test_zh # seed_tts_eval test_zh
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
@@ -37,11 +44,12 @@ from train import (
    add_model_arguments,
    get_model,
    get_tokenizer,
    interpolate_tokens,
    load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
from icefall.utils import str2bool
def get_parser():
@@ -94,9 +102,17 @@ def get_parser():
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
    parser.add_argument(
        "--interpolate-token",
        type=str2bool,
        default=True,
        help="Interpolate semantic tokens to match mel frames for CosyVoice.",
    )
    parser.add_argument(
        "--use-cosyvoice-semantic-token",
        type=str2bool,
        default=False,
        help="Whether to use CosyVoice semantic tokens to replace text tokens.",
    )
    parser.add_argument(
@@ -277,7 +293,7 @@ def get_inference_prompt_cosy_voice_huggingface(
    num_buckets=200,
    min_secs=3,
    max_secs=40,
    interpolate_token=False,
):
    prompts_all = []
@@ -319,15 +335,15 @@ def get_inference_prompt_cosy_voice_huggingface(
            ref_audio = ref_audio_org
        input_tokens = prompt_audio_tokens + audio_tokens
        if interpolate_token:
            input_tokens = interpolate_tokens(input_tokens)
        text_list = input_tokens
        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        total_mel_len = len(input_tokens)
        if not interpolate_token:
            # 25 tokens/s vs 24000/256 = 93.75 mel frames/s -> scale by 15/4
            total_mel_len = int(total_mel_len / 4 * 15)
        # to mel spectrogram
@@ -406,6 +422,51 @@ def get_inference_prompt_cosy_voice_huggingface(
    return prompts_all
def inference_speech_token(
    cosyvoice,
    tts_text,
    prompt_text,
    prompt_speech_16k,
    stream=False,
    speed=1.0,
    text_frontend=True,
):
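    """Extract CosyVoice2 speech tokens for tts_text, conditioned on the 16 kHz prompt.

    Returns (generated speech tokens, prompt speech tokens).
    """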
    tokens = []
    prompt_text = cosyvoice.frontend.text_normalize(
        prompt_text, split=False, text_frontend=text_frontend
    )
    for i in cosyvoice.frontend.text_normalize(
        tts_text, split=True, text_frontend=text_frontend
    ):
        tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
        (
            prompt_text_token,
            prompt_text_token_len,
        ) = cosyvoice.frontend._extract_text_token(prompt_text)
        speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
            prompt_speech_16k
        )
        # stream speech tokens from the CosyVoice2 LLM for this text segment
        for i in cosyvoice.model.llm.inference(
            text=tts_text_token.to(cosyvoice.model.device),
            text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
                cosyvoice.model.device
            ),
            prompt_text=prompt_text_token.to(cosyvoice.model.device),
            prompt_text_len=torch.tensor(
                [prompt_text_token.shape[1]], dtype=torch.int32
            ).to(cosyvoice.model.device),
            prompt_speech_token=speech_token.to(cosyvoice.model.device),
            prompt_speech_token_len=torch.tensor(
                [speech_token.shape[1]], dtype=torch.int32
            ).to(cosyvoice.model.device),
            embedding=None,
        ):
            tokens.append(i)
    return tokens, speech_token
def get_inference_prompt_cosy_voice(
    metainfo,
    speed=1.0,
@@ -423,18 +484,21 @@ def get_inference_prompt_cosy_voice(
    num_buckets=200,
    min_secs=3,
    max_secs=40,
    interpolate_token=False,
):
    import sys

    # please change the CosyVoice paths accordingly
    sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
    sys.path.append("/workspace/CosyVoice")
    from cosyvoice.cli.cosyvoice import CosyVoice2

    # please download the CosyVoice2 model first
    cosyvoice = CosyVoice2(
        "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
    )
    prompts_all = []
    min_tokens = min_secs * target_sample_rate // hop_length
@@ -466,8 +530,8 @@ def get_inference_prompt_cosy_voice(
            ref_audio_16k = resampler(ref_audio_org)
        else:
            ref_audio_16k = ref_audio_org
        audio_tokens, prompt_audio_tokens = inference_speech_token(
            cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
        )
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
@@ -499,8 +563,8 @@ def get_inference_prompt_cosy_voice(
        # convert it into a list
        # input_tokens_list = input_tokens.squeeze().cpu().tolist()
        if interpolate_token:
            input_tokens = interpolate_tokens(input_tokens)
        text_list = input_tokens
        # Duration, mel frame length
@@ -521,7 +585,7 @@ def get_inference_prompt_cosy_voice(
            ref_mel_len / ref_text_len * gen_text_len / speed
        )
        total_mel_len = len(input_tokens)
        if not interpolate_token:
            total_mel_len = int(total_mel_len / 4 * 15)
        print(
            f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
@@ -632,33 +696,35 @@ def main():
    device = f"cuda:{accelerator.process_index}"
    if args.manifest_file:
        metainfo = get_seedtts_testset_metainfo(args.manifest_file)
        if not args.use_cosyvoice_semantic_token:
            prompts_all = get_inference_prompt(
                metainfo,
                speed=1.0,
                tokenizer="pinyin",
                target_sample_rate=24_000,
                n_mel_channels=100,
                hop_length=256,
                mel_spec_type="bigvgan",
                target_rms=0.1,
                use_truth_duration=False,
                infer_batch_size=1,
            )
        else:
            prompts_all = get_inference_prompt_cosy_voice(
                metainfo,
                speed=1.0,
                tokenizer="pinyin",
                target_sample_rate=24_000,
                n_mel_channels=100,
                hop_length=256,
                mel_spec_type="bigvgan",
                target_rms=0.1,
                use_truth_duration=False,
                infer_batch_size=1,
                interpolate_token=args.interpolate_token,
            )
    else:
        assert args.use_cosyvoice_semantic_token
        dataset = datasets.load_dataset(
            "yuekai/seed_tts_cosy2",
            split=args.split_name,
@@ -675,7 +741,7 @@ def main():
            target_rms=0.1,
            use_truth_duration=False,
            infer_batch_size=1,
            interpolate_token=args.interpolate_token,
        )
    vocoder = BigVGANInference.from_pretrained(
@@ -712,14 +778,15 @@ def main():
    ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
    total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
    if args.use_cosyvoice_semantic_token:
        # concat final_text_list
        max_len = max([len(tokens) for tokens in final_text_list])
        # pad tokens to the same length
        for i, tokens in enumerate(final_text_list):
            final_text_list[i] = torch.tensor(
                tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
            )
        final_text_list = torch.stack(final_text_list).to(device)
    # Inference
    with torch.inference_mode():

f5-tts/train.py

@@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma
  --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
  --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
  --exp-dir ${exp_dir} --world-size ${world_size}
# command for training with cosyvoice semantic token
exp_dir=exp/f5-tts-cosyvoice
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
--num-epochs 10 --start-epoch 1 --start-batch 0 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size} \
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
""" """
import argparse import argparse
@@ -303,6 +313,13 @@ def get_parser():
        help="perform OOM check on dataloader batches before starting training.",
    )
    parser.add_argument(
        "--use-cosyvoice-semantic-token",
        type=str2bool,
        default=False,
        help="Whether to use CosyVoice semantic tokens to replace text tokens.",
    )
    add_model_arguments(parser)
    return parser
@@ -378,9 +395,11 @@ def get_tokenizer(vocab_file_path: str):
def get_model(params):
    if params.use_cosyvoice_semantic_token:
        # 6561 is the CosyVoice2 speech-token vocabulary size; see:
        # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
        vocab_char_map, vocab_size = None, 6561
    else:
        vocab_char_map, vocab_size = get_tokenizer(params.tokens)
    # bigvgan 100 dim features
    n_mel_channels = 100
    n_fft = 1024
@@ -558,37 +577,43 @@ def save_checkpoint(
        copyfile(src=filename, dst=best_valid_filename)
def interpolate_tokens(cosy_tokens, pad_token=-1):
    """Interpolate CosyVoice tokens to match the BigVGAN frame length."""
    # cosyvoice: 25 tokens/sec
    # bigvgan: sample_rate/hop_length = 24000/256 = 93.75 frames/sec
    # For every 4 cosyvoice tokens, insert pad tokens to extend them to 15 tokens,
    # matching the bigvgan frame length; we choose 4,4,4,3 to reach 15 frames.
    three, two = [pad_token] * 3, [pad_token] * 2
    return [
        x
        for i, e in enumerate(cosy_tokens)
        for x in ([e] + three if i % 4 < 3 else [e] + two)
    ]
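# Example: 4 tokens [t0, t1, t2, t3] expand to 15 frames:
# [t0, -1, -1, -1, t1, -1, -1, -1, t2, -1, -1, -1, t3, -1, -1]
# i.e. 25 tokens/sec * 15/4 = 93.75 frames/sec, matching bigvgan.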
def prepare_input(
    batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool
):
    """Parse batch data"""
    mel_spec = batch["features"]
    mel_lengths = batch["features_lens"]
    if use_cosyvoice_semantic_token:
        semantic_tokens = []
        for i in range(len(batch["tokens"])):
            tokens = batch["tokens"][i]
            tokens = interpolate_tokens(tokens)
            semantic_tokens.append(tokens)
        # pad to the same length, (B, T), with pad value -1
        max_len = max([len(tokens) for tokens in semantic_tokens])
        text_inputs = torch.full(
            (len(semantic_tokens), max_len), -1, dtype=torch.long
        ).to(device)
        for i, tokens in enumerate(semantic_tokens):
            text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
    else:
        text_inputs = batch["text"]
        text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
    return text_inputs, mel_spec.to(device), mel_lengths.to(device)
@@ -619,7 +644,11 @@ def compute_loss(
    values >= 1.0 are fully warmed up and have all modules present.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    (text_inputs, mel_spec, mel_lengths) = prepare_input(
        batch,
        device=device,
        use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token,
    )
    # at entry, TextTokens is (N, P)
    with torch.set_grad_enabled(is_training):