diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md
index d47687acd..cb41c67c7 100644
--- a/egs/wenetspeech4tts/TTS/README.md
+++ b/egs/wenetspeech4tts/TTS/README.md
@@ -1,3 +1,10 @@
+# Results
+| Model | Seed-TTS test_zh CER | Comment |
+|---------------------------------------|---------------------|--------|
+| [vall-e](./valle) | 4.33% | ~150M |
+| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small config, ~155M |
+| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Uses pretrained CosyVoice2 semantic tokens as inputs rather than text tokens, ~155M |
+
 # Introduction

 [**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
@@ -131,6 +138,53 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
 bash local/compute_wer.sh $output_dir $manifest
 ```

+# F5-TTS-Semantic-Token
+
+./f5-tts contains the code for training F5-TTS-Semantic-Token, which replaces the text tokens in F5-TTS with pretrained CosyVoice2 semantic tokens.
+
+With this change we observed faster convergence and better prosody modeling.
+
+Generated samples and training logs for the WenetSpeech4TTS Basic subset (about 7k hours) can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).
+
+Preparation:
+
+```
+# extract CosyVoice2 semantic tokens
+bash prepare.sh --stage 5 --stop_stage 7
+```
+
+The training command is given below:
+
+```
+# docker: ghcr.io/swivid/f5-tts:main
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
+
+world_size=8
+exp_dir=exp/f5-tts-semantic-token-small
+python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
+  --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
+  --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
+  --num-epochs 10 --start-epoch 1 --start-batch 0 \
+  --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
+  --exp-dir ${exp_dir} --world-size ${world_size} \
+  --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
+```
+
+To run inference with the WenetSpeech4TTS-trained F5-Small-Semantic-Token model from icefall, use:
+```
+huggingface-cli login
+huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
+huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
+
+split=test_zh
+model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
+
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
+bash local/compute_wer.sh $output_dir $manifest
+```
+
 # Credits
 - [VALL-E](https://github.com/lifeiteng/vall-e)
 - [F5-TTS](https://github.com/SWivid/F5-TTS)
+- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py
index beccd3c8d..6964a43be 100644
--- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py
@@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
     --epoch 56 \
     --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small + +# command for text token input accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 + +# command for cosyvoice semantic token input +split=test_zh # seed_tts_eval test_zh +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True + bash local/compute_wer.sh $output_dir $manifest """ import argparse @@ -37,11 +44,12 @@ from train import ( add_model_arguments, get_model, get_tokenizer, - insert_zeros_optimized, + interpolate_tokens, load_F5_TTS_pretrained_checkpoint, ) from icefall.checkpoint import load_checkpoint +from icefall.utils import str2bool def get_parser(): @@ -94,9 +102,17 @@ def get_parser(): parser.add_argument("-ss", "--swaysampling", default=-1, type=float) parser.add_argument( - "--insert-zero", - action="store_true", - help="Insert zeros for CosyVoice", + "--interpolate-token", + type=str2bool, + default=True, + help="Interpolate semantic token to match mel frames for CosyVoice", + ) + + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", ) parser.add_argument( @@ -277,7 +293,7 @@ def get_inference_prompt_cosy_voice_huggingface( num_buckets=200, min_secs=3, max_secs=40, - insert_zero=False, + interpolate_token=False, ): prompts_all = [] @@ -319,15 +335,15 @@ def get_inference_prompt_cosy_voice_huggingface( ref_audio = ref_audio_org input_tokens = prompt_audio_tokens + audio_tokens - if insert_zero: - input_tokens = insert_zeros_optimized(input_tokens) + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) text_list = input_tokens # Duration, mel frame length ref_mel_len = ref_audio.shape[-1] // hop_length total_mel_len = len(input_tokens) - if not insert_zero: + if not interpolate_token: total_mel_len = int(total_mel_len / 4 * 15) # to mel spectrogram @@ -406,6 +422,51 @@ def get_inference_prompt_cosy_voice_huggingface( return prompts_all +def inference_speech_token( + cosyvoice, + tts_text, + prompt_text, + prompt_speech_16k, + stream=False, + speed=1.0, + text_frontend=True, +): + tokens = [] + prompt_text = cosyvoice.frontend.text_normalize( + prompt_text, split=False, text_frontend=text_frontend + ) + for i in cosyvoice.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ): + + tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) + ( + prompt_text_token, + prompt_text_token_len, + ) = cosyvoice.frontend._extract_text_token(prompt_text) + speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( + prompt_speech_16k + ) + + for i in cosyvoice.model.llm.inference( + text=tts_text_token.to(cosyvoice.model.device), + text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( + cosyvoice.model.device + ), + prompt_text=prompt_text_token.to(cosyvoice.model.device), + prompt_text_len=torch.tensor( + [prompt_text_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + prompt_speech_token=speech_token.to(cosyvoice.model.device), + prompt_speech_token_len=torch.tensor( + [speech_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + embedding=None, + ): + tokens.append(i) + return tokens, speech_token + + def get_inference_prompt_cosy_voice( metainfo, 
speed=1.0, @@ -423,18 +484,21 @@ def get_inference_prompt_cosy_voice( num_buckets=200, min_secs=3, max_secs=40, - insert_zero=False, + interpolate_token=False, ): import sys + # please change the path to the cosyvoice accordingly sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") sys.path.append("/workspace/CosyVoice") from cosyvoice.cli.cosyvoice import CosyVoice2 + # please download the cosyvoice model first cosyvoice = CosyVoice2( "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False ) + prompts_all = [] min_tokens = min_secs * target_sample_rate // hop_length @@ -466,8 +530,8 @@ def get_inference_prompt_cosy_voice( ref_audio_16k = resampler(ref_audio_org) else: ref_audio_16k = ref_audio_org - audio_tokens, prompt_audio_tokens = cosyvoice.inference_speech_token( - gt_text, prompt_text, ref_audio_16k, stream=False + audio_tokens, prompt_audio_tokens = inference_speech_token( + cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False ) ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) @@ -499,8 +563,8 @@ def get_inference_prompt_cosy_voice( # convert it into a list # input_tokens_list = input_tokens.squeeze().cpu().tolist() - if insert_zero: - input_tokens = insert_zeros_optimized(input_tokens) + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) text_list = input_tokens # Duration, mel frame length @@ -521,7 +585,7 @@ def get_inference_prompt_cosy_voice( ref_mel_len / ref_text_len * gen_text_len / speed ) total_mel_len = len(input_tokens) - if not insert_zero: + if not interpolate_token: total_mel_len = int(total_mel_len / 4 * 15) print( f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" @@ -632,33 +696,35 @@ def main(): device = f"cuda:{accelerator.process_index}" if args.manifest_file: metainfo = get_seedtts_testset_metainfo(args.manifest_file) - # prompts_all = get_inference_prompt( - # metainfo, - # speed=1.0, - # tokenizer="pinyin", - # target_sample_rate=24_000, - # n_mel_channels=100, - # hop_length=256, - # mel_spec_type="bigvgan", - # target_rms=0.1, - # use_truth_duration=False, - # infer_batch_size=1, - # ) - - prompts_all = get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - insert_zero=args.insert_zero, - ) + if not args.use_cosyvoice_semantic_token: + prompts_all = get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + ) + else: + prompts_all = get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) else: + assert args.use_cosyvoice_semantic_token dataset = datasets.load_dataset( "yuekai/seed_tts_cosy2", split=args.split_name, @@ -675,7 +741,7 @@ def main(): target_rms=0.1, use_truth_duration=False, infer_batch_size=1, - insert_zero=args.insert_zero, + interpolate_token=args.interpolate_token, ) vocoder = BigVGANInference.from_pretrained( @@ -712,14 +778,15 @@ def main(): ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) total_mel_lens = 
torch.tensor(total_mel_lens, dtype=torch.long).to(device) - # concat final_text_list - max_len = max([len(tokens) for tokens in final_text_list]) - # pad tokens to the same length - for i, tokens in enumerate(final_text_list): - final_text_list[i] = torch.tensor( - tokens + [-1] * (max_len - len(tokens)), dtype=torch.long - ) - final_text_list = torch.stack(final_text_list).to(device) + if args.use_cosyvoice_semantic_token: + # concat final_text_list + max_len = max([len(tokens) for tokens in final_text_list]) + # pad tokens to the same length + for i, tokens in enumerate(final_text_list): + final_text_list[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + final_text_list = torch.stack(final_text_list).to(device) # Inference with torch.inference_mode(): diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 7a22b455f..5333b3f27 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ --exp-dir ${exp_dir} --world-size ${world_size} + +# command for training with cosyvoice semantic token +exp_dir=exp/f5-tts-cosyvoice +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \ + --num-epochs 10 --start-epoch 1 --start-batch 0 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} \ + --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True """ import argparse @@ -303,6 +313,13 @@ def get_parser(): help="perform OOM check on dataloader batches before starting training.", ) + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + add_model_arguments(parser) return parser @@ -378,9 +395,11 @@ def get_tokenizer(vocab_file_path: str): def get_model(params): - vocab_char_map, vocab_size = get_tokenizer(params.tokens) - vocab_char_map, vocab_size = None, 6561 - # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36 + if params.use_cosyvoice_semantic_token: + # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36 + vocab_char_map, vocab_size = None, 6561 + else: + vocab_char_map, vocab_size = get_tokenizer(params.tokens) # bigvgan 100 dim features n_mel_channels = 100 n_fft = 1024 @@ -558,37 +577,43 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) -def insert_zeros_optimized(arr): +def interpolate_tokens(cosy_tokens, pad_token=-1): + """Interpolate cosyvoice tokens to match bigvgan frames length""" # cosyvoice, 25 tokens/sec # bigvgan sample_rate/hop_length 24000/256 frames/sec # For every 4 cosyvoice tokens, insert pad tokens to extend it to 15 tokens to match bigvgan frames length # We choose 4,4,4,3 to match 15 frames - three, two = [-1] * 3, [-1] * 2 + three, two = [pad_token] * 3, [pad_token] * 2 return [ - x for i, e in enumerate(arr) for x in ([e] + three if i % 4 < 3 else [e] + two) + x + for i, e in enumerate(cosy_tokens) + for x in ([e] + three if i % 4 < 3 
else [e] + two) ] -def prepare_input(batch: dict, device: torch.device): +def prepare_input( + batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool +): """Parse batch data""" - # text_inputs = batch["text"] - # text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) - mel_spec = batch["features"] mel_lengths = batch["features_lens"] - semantic_tokens = [] - for i in range(len(batch["tokens"])): - tokens = batch["tokens"][i] - # tokens = insert_zeros_optimized(tokens) - semantic_tokens.append(tokens) - # pad to the same length, B,T, with pad value -1 - max_len = max([len(tokens) for tokens in semantic_tokens]) - text_inputs = torch.full((len(semantic_tokens), max_len), -1, dtype=torch.long).to( - device - ) - for i, tokens in enumerate(semantic_tokens): - text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + if use_cosyvoice_semantic_token: + semantic_tokens = [] + for i in range(len(batch["tokens"])): + tokens = batch["tokens"][i] + tokens = interpolate_tokens(tokens) + semantic_tokens.append(tokens) + # pad to the same length, B,T, with pad value -1 + max_len = max([len(tokens) for tokens in semantic_tokens]) + text_inputs = torch.full( + (len(semantic_tokens), max_len), -1, dtype=torch.long + ).to(device) + for i, tokens in enumerate(semantic_tokens): + text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + else: + text_inputs = batch["text"] + text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) return text_inputs, mel_spec.to(device), mel_lengths.to(device) @@ -619,7 +644,11 @@ def compute_loss( values >= 1.0 are fully warmed up and have all modules present. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device - (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device) + (text_inputs, mel_spec, mel_lengths) = prepare_input( + batch, + device=device, + use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token, + ) # at entry, TextTokens is (N, P) with torch.set_grad_enabled(is_training):
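
A note on the token-rate matching above: BigVGAN mel features run at 24000 / 256 = 93.75 frames per second while CosyVoice2 semantic tokens arrive at 25 tokens per second, so each group of 4 tokens has to cover 15 mel frames (hence the 4/4/4/3 padding pattern, and the `total_mel_len / 4 * 15` fallback when interpolation is disabled). The following is a minimal standalone sketch of the `interpolate_tokens` logic from `f5-tts/train.py`; the demo values and the length assertion are illustrative, not part of the recipe:

```
# Minimal sketch of the 4-tokens -> 15-frames interpolation used by
# interpolate_tokens() in f5-tts/train.py. Only the padding pattern comes
# from the patch; the demo token values below are made up.

def interpolate_tokens(cosy_tokens, pad_token=-1):
    """Stretch 25 Hz CosyVoice2 tokens to ~93.75 Hz mel frames.

    24000 / 256 = 93.75 mel frames/sec and 93.75 / 25 = 15 / 4, so each
    group of 4 tokens becomes 15 frames: the first 3 tokens in a group get
    3 pad tokens each and the 4th gets 2 (4 + 4 + 4 + 3 = 15).
    """
    three, two = [pad_token] * 3, [pad_token] * 2
    return [
        x
        for i, e in enumerate(cosy_tokens)
        for x in ([e] + three if i % 4 < 3 else [e] + two)
    ]


if __name__ == "__main__":
    tokens = list(range(8))  # 8 tokens = 2 groups of 4
    frames = interpolate_tokens(tokens)
    assert len(frames) == 30  # 2 groups * 15 frames
    print(frames)
```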
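Both `prepare_input` in `train.py` and the `final_text_list` handling in `infer.py` pad variable-length semantic-token sequences with -1 up to the longest sequence in the batch before stacking them into a `(B, T)` tensor. A small self-contained sketch of that padding, with made-up token values:

```
# Sketch of the pad-to-max-length batching used for CosyVoice2 semantic
# tokens in prepare_input() / main(); the token values here are made up.
import torch


def pad_token_lists(token_lists, pad_value=-1):
    """Pad a list of variable-length token lists into a (B, T) LongTensor."""
    max_len = max(len(tokens) for tokens in token_lists)
    batch = torch.full((len(token_lists), max_len), pad_value, dtype=torch.long)
    for i, tokens in enumerate(token_lists):
        batch[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
    return batch


if __name__ == "__main__":
    batch = pad_token_lists([[11, 12, 13], [21, 22], [31]])
    print(batch)
    # tensor([[11, 12, 13],
    #         [21, 22, -1],
    #         [31, -1, -1]])
```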
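The new options are declared with `str2bool` (imported from `icefall.utils`) rather than `action="store_true"`, which is why the commands above can pass an explicit value such as `--use-cosyvoice-semantic-token True`. A rough sketch of the idea with a locally defined `str2bool`; the icefall implementation may differ in detail:

```
# Sketch of why str2bool-style flags accept "--flag True" / "--flag False"
# on the command line. This local str2bool only approximates
# icefall.utils.str2bool and is for illustration.
import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-cosyvoice-semantic-token",
    type=str2bool,
    default=False,
    help="Whether to use cosyvoice semantic token to replace text token.",
)
args = parser.parse_args(["--use-cosyvoice-semantic-token", "True"])
assert args.use_cosyvoice_semantic_token is True
```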