mirror of https://github.com/k2-fsa/icefall.git

update results

This commit is contained in:
parent a54a0469a2
commit 2edaf685e1
@@ -1,3 +1,10 @@
# Results
| Model | Seed-TTS test_zh CER | Comment |
|---------------------------------------|---------------------|--------|
| [vall-e](./valle) | 4.33% | ~150M |
| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M |
| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M |

# Introduction

[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
@@ -131,6 +138,53 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
bash local/compute_wer.sh $output_dir $manifest
```

# F5-TTS-Semantic-Token

./f5-tts contains the code for training F5-TTS-Semantic-Token. We replaced the text tokens in F5-TTS with pretrained CosyVoice2 semantic tokens.

We observed faster convergence and better prosody modeling as a result.
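For intuition about the rate matching this requires: CosyVoice2 semantic tokens arrive at 25 tokens/sec, while the BigVGAN mel features F5-TTS is trained on run at 24000 / 256 = 93.75 frames/sec, i.e. 3.75 frames per token. The interpolate_tokens helper added to train.py in this commit therefore pads every group of 4 tokens out to 15 slots (4+4+4+3). A standalone copy of that helper to check the mapping:

```
# interpolate_tokens, copied from the train.py change below so it can run standalone.
def interpolate_tokens(cosy_tokens, pad_token=-1):
    """Pad 25 tokens/sec CosyVoice2 tokens up to the 93.75 frames/sec mel rate."""
    three, two = [pad_token] * 3, [pad_token] * 2
    return [
        x
        for i, e in enumerate(cosy_tokens)
        for x in ([e] + three if i % 4 < 3 else [e] + two)
    ]


print(interpolate_tokens([10, 11, 12, 13]))
# [10, -1, -1, -1, 11, -1, -1, -1, 12, -1, -1, -1, 13, -1, -1]
# 4 tokens -> 15 slots, i.e. 4 / 25 sec * 93.75 frames/sec = 15 frames.
```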
Generated samples and training logs for the 7k-hour WenetSpeech4TTS Basic data can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).

Preparation:

```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 5 --stop_stage 7
```

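Stages 5-7 extract one CosyVoice2 semantic-token sequence per utterance. Conceptually, the extraction looks like the sketch below, which reuses the paths and frontend call from f5-tts/infer.py in this commit (example.wav is a placeholder):

```
import sys

# paths as used in f5-tts/infer.py; adjust to your CosyVoice checkout and model dir
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice")

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2(
    "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)

wav, sr = torchaudio.load("example.wav")  # placeholder utterance
wav_16k = torchaudio.functional.resample(wav, sr, 16000)
# same frontend call used by inference_speech_token in f5-tts/infer.py
tokens, tokens_len = cosyvoice.frontend._extract_speech_token(wav_16k)
```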
The training command is given below:

```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece

world_size=8
exp_dir=exp/f5-tts-semantic-token-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
  --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
  --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
  --num-epochs 10 --start-epoch 1 --start-batch 0 \
  --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
  --exp-dir ${exp_dir} --world-size ${world_size} \
  --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
```

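The released checkpoint `epoch-10-avg-5.pt` used below is an averaged checkpoint. A plausible averaging command, following the `generate_averaged_model.py` usage documented in `f5-tts/infer.py` (the `--epoch`/`--avg` values are inferred from the checkpoint name, not stated elsewhere in this commit):

```
python3 f5-tts/generate_averaged_model.py \
  --epoch 10 --avg 5 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
  --exp-dir ${exp_dir}
```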
To run inference with the F5-Small-Semantic-Token model trained on Icefall WenetSpeech4TTS, use:
```
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x

split=test_zh
model_path=${exp_dir}/epoch-10-avg-5.pt

accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
```

# Credits
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)

@@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
    --epoch 56 \
    --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
    --exp-dir exp/f5_small

# command for text token input
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18

# command for cosyvoice semantic token input
split=test_zh # seed_tts_eval test_zh
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True

bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
@@ -37,11 +44,12 @@ from train import (
    add_model_arguments,
    get_model,
    get_tokenizer,
-    insert_zeros_optimized,
+    interpolate_tokens,
    load_F5_TTS_pretrained_checkpoint,
)

from icefall.checkpoint import load_checkpoint
from icefall.utils import str2bool


def get_parser():
@@ -94,9 +102,17 @@ def get_parser():
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument(
-        "--insert-zero",
-        action="store_true",
-        help="Insert zeros for CosyVoice",
+        "--interpolate-token",
+        type=str2bool,
+        default=True,
+        help="Interpolate semantic token to match mel frames for CosyVoice",
    )

+    parser.add_argument(
+        "--use-cosyvoice-semantic-token",
+        type=str2bool,
+        default=False,
+        help="Whether to use cosyvoice semantic token to replace text token.",
+    )
+
    parser.add_argument(
@@ -277,7 +293,7 @@ def get_inference_prompt_cosy_voice_huggingface(
    num_buckets=200,
    min_secs=3,
    max_secs=40,
-    insert_zero=False,
+    interpolate_token=False,
):
    prompts_all = []

@@ -319,15 +335,15 @@ def get_inference_prompt_cosy_voice_huggingface(
            ref_audio = ref_audio_org
        input_tokens = prompt_audio_tokens + audio_tokens

-        if insert_zero:
-            input_tokens = insert_zeros_optimized(input_tokens)
+        if interpolate_token:
+            input_tokens = interpolate_tokens(input_tokens)
        text_list = input_tokens

        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length

        total_mel_len = len(input_tokens)
-        if not insert_zero:
+        if not interpolate_token:
            total_mel_len = int(total_mel_len / 4 * 15)

        # to mel spectrogram
@@ -406,6 +422,51 @@ def get_inference_prompt_cosy_voice_huggingface(
    return prompts_all


+def inference_speech_token(
+    cosyvoice,
+    tts_text,
+    prompt_text,
+    prompt_speech_16k,
+    stream=False,
+    speed=1.0,
+    text_frontend=True,
+):
+    tokens = []
+    prompt_text = cosyvoice.frontend.text_normalize(
+        prompt_text, split=False, text_frontend=text_frontend
+    )
+    for i in cosyvoice.frontend.text_normalize(
+        tts_text, split=True, text_frontend=text_frontend
+    ):

+        tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
+        (
+            prompt_text_token,
+            prompt_text_token_len,
+        ) = cosyvoice.frontend._extract_text_token(prompt_text)
+        speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
+            prompt_speech_16k
+        )

+        for i in cosyvoice.model.llm.inference(
+            text=tts_text_token.to(cosyvoice.model.device),
+            text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
+                cosyvoice.model.device
+            ),
+            prompt_text=prompt_text_token.to(cosyvoice.model.device),
+            prompt_text_len=torch.tensor(
+                [prompt_text_token.shape[1]], dtype=torch.int32
+            ).to(cosyvoice.model.device),
+            prompt_speech_token=speech_token.to(cosyvoice.model.device),
+            prompt_speech_token_len=torch.tensor(
+                [speech_token.shape[1]], dtype=torch.int32
+            ).to(cosyvoice.model.device),
+            embedding=None,
+        ):
+            tokens.append(i)
+    return tokens, speech_token


def get_inference_prompt_cosy_voice(
    metainfo,
    speed=1.0,
@@ -423,18 +484,21 @@ def get_inference_prompt_cosy_voice(
    num_buckets=200,
    min_secs=3,
    max_secs=40,
-    insert_zero=False,
+    interpolate_token=False,
):

    import sys

    # please change the path to the cosyvoice accordingly
    sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
    sys.path.append("/workspace/CosyVoice")
    from cosyvoice.cli.cosyvoice import CosyVoice2

    # please download the cosyvoice model first
    cosyvoice = CosyVoice2(
        "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
    )

    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
@@ -466,8 +530,8 @@ def get_inference_prompt_cosy_voice(
            ref_audio_16k = resampler(ref_audio_org)
        else:
            ref_audio_16k = ref_audio_org
-        audio_tokens, prompt_audio_tokens = cosyvoice.inference_speech_token(
-            gt_text, prompt_text, ref_audio_16k, stream=False
+        audio_tokens, prompt_audio_tokens = inference_speech_token(
+            cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
        )

        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
@@ -499,8 +563,8 @@ def get_inference_prompt_cosy_voice(

        # convert it into a list
        # input_tokens_list = input_tokens.squeeze().cpu().tolist()
-        if insert_zero:
-            input_tokens = insert_zeros_optimized(input_tokens)
+        if interpolate_token:
+            input_tokens = interpolate_tokens(input_tokens)
        text_list = input_tokens

        # Duration, mel frame length
@@ -521,7 +585,7 @@ def get_inference_prompt_cosy_voice(
            ref_mel_len / ref_text_len * gen_text_len / speed
        )
        total_mel_len = len(input_tokens)
-        if not insert_zero:
+        if not interpolate_token:
            total_mel_len = int(total_mel_len / 4 * 15)
        print(
            f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
@@ -632,33 +696,35 @@ def main():
    device = f"cuda:{accelerator.process_index}"
    if args.manifest_file:
        metainfo = get_seedtts_testset_metainfo(args.manifest_file)
-        # prompts_all = get_inference_prompt(
-        #     metainfo,
-        #     speed=1.0,
-        #     tokenizer="pinyin",
-        #     target_sample_rate=24_000,
-        #     n_mel_channels=100,
-        #     hop_length=256,
-        #     mel_spec_type="bigvgan",
-        #     target_rms=0.1,
-        #     use_truth_duration=False,
-        #     infer_batch_size=1,
-        # )

-        prompts_all = get_inference_prompt_cosy_voice(
-            metainfo,
-            speed=1.0,
-            tokenizer="pinyin",
-            target_sample_rate=24_000,
-            n_mel_channels=100,
-            hop_length=256,
-            mel_spec_type="bigvgan",
-            target_rms=0.1,
-            use_truth_duration=False,
-            infer_batch_size=1,
-            insert_zero=args.insert_zero,
-        )
+        if not args.use_cosyvoice_semantic_token:
+            prompts_all = get_inference_prompt(
+                metainfo,
+                speed=1.0,
+                tokenizer="pinyin",
+                target_sample_rate=24_000,
+                n_mel_channels=100,
+                hop_length=256,
+                mel_spec_type="bigvgan",
+                target_rms=0.1,
+                use_truth_duration=False,
+                infer_batch_size=1,
+            )
+        else:
+            prompts_all = get_inference_prompt_cosy_voice(
+                metainfo,
+                speed=1.0,
+                tokenizer="pinyin",
+                target_sample_rate=24_000,
+                n_mel_channels=100,
+                hop_length=256,
+                mel_spec_type="bigvgan",
+                target_rms=0.1,
+                use_truth_duration=False,
+                infer_batch_size=1,
+                interpolate_token=args.interpolate_token,
+            )
    else:
        assert args.use_cosyvoice_semantic_token
        dataset = datasets.load_dataset(
            "yuekai/seed_tts_cosy2",
            split=args.split_name,
@@ -675,7 +741,7 @@ def main():
            target_rms=0.1,
            use_truth_duration=False,
            infer_batch_size=1,
-            insert_zero=args.insert_zero,
+            interpolate_token=args.interpolate_token,
        )

    vocoder = BigVGANInference.from_pretrained(
@@ -712,14 +778,15 @@ def main():
    ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
    total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

-    # concat final_text_list
-    max_len = max([len(tokens) for tokens in final_text_list])
-    # pad tokens to the same length
-    for i, tokens in enumerate(final_text_list):
-        final_text_list[i] = torch.tensor(
-            tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
-        )
-    final_text_list = torch.stack(final_text_list).to(device)
+    if args.use_cosyvoice_semantic_token:
+        # concat final_text_list
+        max_len = max([len(tokens) for tokens in final_text_list])
+        # pad tokens to the same length
+        for i, tokens in enumerate(final_text_list):
+            final_text_list[i] = torch.tensor(
+                tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
+            )
+        final_text_list = torch.stack(final_text_list).to(device)

    # Inference
    with torch.inference_mode():
@@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma
  --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
  --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
  --exp-dir ${exp_dir} --world-size ${world_size}

# command for training with cosyvoice semantic token
exp_dir=exp/f5-tts-cosyvoice
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
  --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
  --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
  --num-epochs 10 --start-epoch 1 --start-batch 0 \
  --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
  --exp-dir ${exp_dir} --world-size ${world_size} \
  --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
"""

import argparse
@@ -303,6 +313,13 @@ def get_parser():
        help="perform OOM check on dataloader batches before starting training.",
    )

+    parser.add_argument(
+        "--use-cosyvoice-semantic-token",
+        type=str2bool,
+        default=False,
+        help="Whether to use cosyvoice semantic token to replace text token.",
+    )
+
    add_model_arguments(parser)

    return parser
@@ -378,9 +395,11 @@ def get_tokenizer(vocab_file_path: str):


def get_model(params):
-    vocab_char_map, vocab_size = get_tokenizer(params.tokens)
-    vocab_char_map, vocab_size = None, 6561
-    # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
+    if params.use_cosyvoice_semantic_token:
+        # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
+        vocab_char_map, vocab_size = None, 6561
+    else:
+        vocab_char_map, vocab_size = get_tokenizer(params.tokens)
    # bigvgan 100 dim features
    n_mel_channels = 100
    n_fft = 1024
@@ -558,37 +577,43 @@ def save_checkpoint(
        copyfile(src=filename, dst=best_valid_filename)


-def insert_zeros_optimized(arr):
+def interpolate_tokens(cosy_tokens, pad_token=-1):
+    """Interpolate cosyvoice tokens to match bigvgan frames length"""
    # cosyvoice, 25 tokens/sec
    # bigvgan sample_rate/hop_length 24000/256 frames/sec
    # For every 4 cosyvoice tokens, insert pad tokens to extend it to 15 tokens to match bigvgan frames length
    # We choose 4,4,4,3 to match 15 frames
-    three, two = [-1] * 3, [-1] * 2
+    three, two = [pad_token] * 3, [pad_token] * 2
    return [
-        x for i, e in enumerate(arr) for x in ([e] + three if i % 4 < 3 else [e] + two)
+        x
+        for i, e in enumerate(cosy_tokens)
+        for x in ([e] + three if i % 4 < 3 else [e] + two)
    ]


-def prepare_input(batch: dict, device: torch.device):
+def prepare_input(
+    batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool
+):
    """Parse batch data"""
-    # text_inputs = batch["text"]
-    # text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)

    mel_spec = batch["features"]
    mel_lengths = batch["features_lens"]

-    semantic_tokens = []
-    for i in range(len(batch["tokens"])):
-        tokens = batch["tokens"][i]
-        # tokens = insert_zeros_optimized(tokens)
-        semantic_tokens.append(tokens)
-    # pad to the same length, B,T, with pad value -1
-    max_len = max([len(tokens) for tokens in semantic_tokens])
-    text_inputs = torch.full((len(semantic_tokens), max_len), -1, dtype=torch.long).to(
-        device
-    )
-    for i, tokens in enumerate(semantic_tokens):
-        text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
+    if use_cosyvoice_semantic_token:
+        semantic_tokens = []
+        for i in range(len(batch["tokens"])):
+            tokens = batch["tokens"][i]
+            tokens = interpolate_tokens(tokens)
+            semantic_tokens.append(tokens)
+        # pad to the same length, B,T, with pad value -1
+        max_len = max([len(tokens) for tokens in semantic_tokens])
+        text_inputs = torch.full(
+            (len(semantic_tokens), max_len), -1, dtype=torch.long
+        ).to(device)
+        for i, tokens in enumerate(semantic_tokens):
+            text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
+    else:
+        text_inputs = batch["text"]
+        text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)

    return text_inputs, mel_spec.to(device), mel_lengths.to(device)

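For reference, the semantic-token branch above packs variable-length token lists into a (B, T) LongTensor padded with -1; a standalone sketch of just that padding:

```
import torch

# Toy token lists of unequal length, padded exactly as in prepare_input above.
semantic_tokens = [[5, 6, 7], [8, 9]]
max_len = max(len(tokens) for tokens in semantic_tokens)
text_inputs = torch.full((len(semantic_tokens), max_len), -1, dtype=torch.long)
for i, tokens in enumerate(semantic_tokens):
    text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
print(text_inputs)
# tensor([[ 5,  6,  7],
#         [ 8,  9, -1]])
```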
@@ -619,7 +644,11 @@ def compute_loss(
    values >= 1.0 are fully warmed up and have all modules present.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
+    (text_inputs, mel_spec, mel_lengths) = prepare_input(
+        batch,
+        device=device,
+        use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token,
+    )
    # at entry, TextTokens is (N, P)

    with torch.set_grad_enabled(is_training):