diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 02e5f0f4d..beccd3c8d 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -22,6 +22,7 @@ import random import time from pathlib import Path +import datasets import torch import torch.nn.functional as F import torchaudio @@ -36,6 +37,7 @@ from train import ( add_model_arguments, get_model, get_tokenizer, + insert_zeros_optimized, load_F5_TTS_pretrained_checkpoint, ) @@ -78,7 +80,7 @@ def get_parser(): parser.add_argument( "--manifest-file", type=str, - default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst", + default=None, help="The manifest file in seed_tts_eval format", ) @@ -90,6 +92,21 @@ def get_parser(): ) parser.add_argument("-ss", "--swaysampling", default=-1, type=float) + + parser.add_argument( + "--insert-zero", + action="store_true", + help="Insert zeros for CosyVoice", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="huggingface dataset split name", + ) + add_model_arguments(parser) return parser.parse_args() @@ -243,6 +260,344 @@ def get_inference_prompt( return prompts_all +def get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + insert_zero=False, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for i in range(len(dataset)): + utt = dataset[i]["id"] + ref_audio_org, ref_sr = ( + dataset[i]["prompt_audio"]["array"], + dataset[i]["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + audio_tokens = dataset[i]["target_audio_cosy2_tokens"] + prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + input_tokens = prompt_audio_tokens + audio_tokens + + if insert_zero: + input_tokens = insert_zeros_optimized(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + + total_mel_len = len(input_tokens) + if not insert_zero: + total_mel_len = int(total_mel_len / 4 * 15) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + if total_mel_len > max_tokens: + print( + f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + ) + continue + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + insert_zero=False, +): + + import sys + + sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + sys.path.append("/workspace/CosyVoice") + from cosyvoice.cli.cosyvoice import CosyVoice2 + + cosyvoice = CosyVoice2( + "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False + ) + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio_org, ref_sr = torchaudio.load(prompt_wav) + + # cosy voice + if ref_sr != 16000: + resampler = torchaudio.transforms.Resample(ref_sr, 16000) + ref_audio_16k = resampler(ref_audio_org) + else: + ref_audio_16k = ref_audio_org + audio_tokens, prompt_audio_tokens = cosyvoice.inference_speech_token( + gt_text, prompt_text, ref_audio_16k, stream=False + ) + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + assert ( + ref_audio_org.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Text + # if len(prompt_text[-1].encode("utf-8")) == 1: + # prompt_text = prompt_text + " " + # text = [prompt_text + gt_text] + # if tokenizer == "pinyin": + # text_list = convert_char_to_pinyin(text, polyphone=polyphone) + # else: + # text_list = text + + # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens + # prompt_audio_tokens shape 1, prompt_audio_tokens + # audio_tokens shape 1, audio_tokens + prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() + input_tokens = prompt_audio_tokens + audio_tokens + + # convert it into a list + # input_tokens_list = input_tokens.squeeze().cpu().tolist() + if insert_zero: + input_tokens = insert_zeros_optimized(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len_compute = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + total_mel_len = len(input_tokens) + if not insert_zero: + total_mel_len = int(total_mel_len / 4 * 15) + print( + f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + def padded_mel_batch(ref_mels): max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() padded_ref_mels = [] @@ -275,20 +630,53 @@ def main(): accelerator = Accelerator() device = f"cuda:{accelerator.process_index}" + if args.manifest_file: + metainfo = get_seedtts_testset_metainfo(args.manifest_file) + # prompts_all = get_inference_prompt( + # metainfo, + # speed=1.0, + # tokenizer="pinyin", + # target_sample_rate=24_000, + # n_mel_channels=100, + # hop_length=256, + # mel_spec_type="bigvgan", + # target_rms=0.1, + # use_truth_duration=False, + # infer_batch_size=1, + # ) - metainfo = get_seedtts_testset_metainfo(args.manifest_file) - prompts_all = get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - ) + prompts_all = get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + insert_zero=args.insert_zero, + ) + else: + dataset = datasets.load_dataset( + "yuekai/seed_tts_cosy2", + split=args.split_name, + trust_remote_code=True, + ) + prompts_all = get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + insert_zero=args.insert_zero, + ) vocoder = BigVGANInference.from_pretrained( "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False @@ -324,6 +712,15 @@ def main(): ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + # concat final_text_list + max_len = max([len(tokens) for tokens in final_text_list]) + # pad tokens to the same length + for i, tokens in enumerate(final_text_list): + final_text_list[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + final_text_list = torch.stack(final_text_list).to(device) + # Inference with torch.inference_mode(): generated, _ = model.sample( diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 61b1c709c..7a22b455f 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -580,7 +580,7 @@ def prepare_input(batch: dict, device: torch.device): semantic_tokens = [] for i in range(len(batch["tokens"])): tokens = batch["tokens"][i] - tokens = insert_zeros_optimized(tokens) + # tokens = insert_zeros_optimized(tokens) semantic_tokens.append(tokens) # pad to the same length, B,T, with pad value -1 max_len = max([len(tokens) for tokens in semantic_tokens]) diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh index b1cc4eb10..cf86b3fa5 100755 --- a/egs/wenetspeech4tts/TTS/prepare.sh +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -130,9 +130,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then data/fbank/${prefix}_cuts_validtest.jsonl.gz \ data/fbank/${prefix}_cuts_test.jsonl.gz - - # zcat "data/fbank/${prefix}_cuts_${subset}.jsonl.gz" | head -n 100 | gzip > "data/fbank/${prefix}_cuts_${subset}_top100.jsonl.gz" - rm data/fbank/${prefix}_cuts_validtest.jsonl.gz n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))