import argparse
import math
import os
import random
import time

import torch
import torch.nn.functional as F
import torchaudio
from accelerate import Accelerator
from bigvganinference import BigVGANInference

# get_inference_prompt and get_seedtts_testset_metainfo below are local copies
# adapted from f5_tts.eval.utils_eval; load_vocoder replaces
# f5_tts.infer.utils_infer.load_vocoder with a BigVGANInference-based version.
from model.cfm import CFM
from model.dit import DiT
from model.modules import MelSpec
from model.utils import convert_char_to_pinyin
from tqdm import tqdm
from train import get_tokenizer, load_pretrained_checkpoint

from icefall.checkpoint import load_checkpoint


def load_vocoder(device):
    # huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
    model = BigVGANInference.from_pretrained(
        "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
    )
    model = model.eval().to(device)
    return model


def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
        metainfo, desc="Processing prompts..."
    ):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert (
            ref_audio.shape[-1] > 5000
        ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text: append a space if the text ends in a single-byte (ASCII) character
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            # estimate the target duration from the byte-length ratio of the texts
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(
                ref_mel_len / ref_text_len * gen_text_len / speed
            )

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
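        # Worked example of the bucket assignment below (assuming the defaults
        # min_secs=3, max_secs=40, num_buckets=200): min_tokens = 281 and
        # max_tokens = 3750 frames, so a 10 s utterance (~937 frames at 24 kHz
        # with hop 256) lands in bucket floor((937 - 281) / 3470 * 200) = 37.
        # Utterances with similar estimated durations are thus batched together.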
        assert (
            min_tokens <= total_mel_len <= max_tokens
        ), f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        bucket_i = math.floor(
            (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
        )

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = ([], [], [], [], [], [])

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )

    # shuffle batches so the last workers are not left with only the easy (short) ones
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all


def padded_mel_batch(ref_mels):
    max_mel_length = max(mel.shape[-1] for mel in ref_mels)
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels


def get_seedtts_testset_metainfo(metalst):
    with open(metalst) as f:
        lines = f.readlines()

    metainfo = []
    for line in lines:
        fields = line.strip().split("|")
        if len(fields) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
        elif len(fields) == 4:
            utt, prompt_text, prompt_wav, gt_text = fields
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        else:
            continue  # skip malformed lines
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo
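# A sketch of the expected meta list format (inferred from the parsing above;
# the field values are illustrative):
#
#   utt_0001|prompt transcript|prompts/utt_0001.wav|target transcript|wavs/utt_0001.wav
#   utt_0002|prompt transcript|prompts/utt_0002.wav|target transcript
#
# With four fields, the ground-truth wav defaults to <metalst dir>/wavs/<utt>.wav;
# relative prompt wav paths are resolved against the meta list's directory.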
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1


def main():
    # ---------------------- infer setting ---------------------- #

    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=15000, type=int)
    parser.add_argument(
        "-m",
        "--mel_spec_type",
        default="bigvgan",
        type=str,
        choices=["bigvgan", "vocos"],
    )
    parser.add_argument(
        "-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
    )
    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
    parser.add_argument("-t", "--testset", required=True)

    args = parser.parse_args()

    seed = args.seed
    dataset_name = args.dataset
    exp_name = args.expname
    ckpt_step = args.ckptstep
    # ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
    ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
    mel_spec_type = args.mel_spec_type
    tokenizer = args.tokenizer
    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling
    testset = args.testset

    infer_batch_size = 1  # max frames; 1 for ddp single inference (recommended)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    model_cls = DiT
    model_cfg = dict(
        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
    )

    metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

    # path to save generated wavs
    output_dir = (
        f"./results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    vocoder = load_vocoder(device)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")

    # Model
    model = CFM(
        transformer=model_cls(
            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
        ),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # model = load_pretrained_checkpoint(model, ckpt_path)
    _ = load_checkpoint(
        ckpt_path,
        model=model,
    )
    model = model.eval().to(device)

    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)
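    # In the loop below, model.sample() generates the full mel sequence
    # (prompt + continuation) for each batch; the prompt prefix is then cut off
    # per utterance, the remainder is vocoded, and the waveform is scaled back
    # down for prompts whose RMS was below target_rms (undoing the boost
    # applied in get_inference_prompt).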
    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            (
                utts,
                ref_rms_list,
                ref_mels,
                ref_mel_lens,
                total_mel_lens,
                final_text_list,
            ) = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
                # Final result
                for i, gen in enumerate(generated):
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec).cpu()
                    elif mel_spec_type == "bigvgan":
                        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(
                        f"{output_dir}/{utts[i]}.wav",
                        generated_wave,
                        target_sample_rate,
                    )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60:.2f} minutes.")


if __name__ == "__main__":
    main()
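# A minimal launch sketch (an assumption, not part of this script: the file name
# `infer.py` and the testset label are illustrative; `accelerate` must be
# configured for the local GPUs):
#
#   accelerate launch infer.py \
#       --expname f5 --ckptstep 15000 --testset seedtts_zh \
#       --mel_spec_type bigvgan --nfestep 32 --seed 0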