import argparse
import math
import os
import random
import time

import torch
import torch.nn.functional as F
import torchaudio
from accelerate import Accelerator

# BigVGAN is loaded through the bigvganinference package; load_vocoder,
# get_inference_prompt and get_seedtts_testset_metainfo below are adapted
# from f5_tts (f5_tts.infer.utils_infer and f5_tts.eval.utils_eval).
from bigvganinference import BigVGANInference
from model.cfm import CFM
from model.dit import DiT
from model.modules import MelSpec
from model.utils import convert_char_to_pinyin
from tqdm import tqdm
from train import get_tokenizer, load_pretrained_checkpoint

from icefall.checkpoint import load_checkpoint


def load_vocoder(device):
    # Download the checkpoint first:
    # huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
    model = BigVGANInference.from_pretrained(
        "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
    )
    model = model.eval().to(device)
    return model
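

# Usage note (an assumption, not from the original script): BigVGANInference
# wraps NVIDIA's BigVGAN, so given a mel batch of shape (B, n_mel, T) it
# returns a waveform of shape (B, 1, T * 256) at 24 kHz, e.g.
#   vocoder = load_vocoder("cuda:0")
#   wav = vocoder(torch.randn(1, 100, 250, device="cuda:0"))  # -> (1, 1, 64000)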


def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length
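    # Worked example with the defaults: min_tokens = 3 * 24000 // 256 = 281
    # mel frames (~3 s) and max_tokens = 40 * 24000 // 256 = 3750 frames (~40 s).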

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
        metainfo, desc="Processing prompts..."
    ):
        # Audio: boost quiet prompts up to target_rms and resample if needed
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert (
            ref_audio.shape[-1] > 5000
        ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text: if the prompt ends with a single-byte (ASCII) character,
        # append a space so the prompt and target text do not run together
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # Duration, in mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
        else:
            # estimate the output duration from the byte-length ratio of the
            # target text to the prompt text
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(
                ref_mel_len / ref_text_len * gen_text_len / speed
            )
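        # Worked example (hypothetical numbers): a 128000-sample prompt gives
        # ref_mel_len = 128000 // 256 = 500 frames; with a 20-byte prompt text,
        # a 40-byte target text and speed=1.0,
        # total_mel_len = 500 + int(500 / 20 * 40 / 1.0) = 1500 frames.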

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert (
            min_tokens <= total_mel_len <= max_tokens
        ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        bucket_i = math.floor(
            (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
        )
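        # Worked example (hypothetical length): with the defaults,
        # total_mel_len = 1500 falls into bucket
        # floor((1500 - 281) / (3750 - 281 + 1) * 200) = 70, so utterances of
        # similar length end up batched together.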

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            # flush this bucket into prompts_all and reset it; infer_batch_size
            # is an accumulated-frame budget, not an utterance count, so the
            # default of 1 flushes every utterance immediately (batch size 1)
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = ([], [], [], [], [], [])

    # add residual (buckets that never reached the frame budget)
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
    # shuffle, so the last processes are not left with only the short batches
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all


def padded_mel_batch(ref_mels):
    # right-pad every mel to the longest in the batch, then stack and
    # transpose to (batch, frames, n_mel)
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels
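

# A minimal usage sketch (hypothetical shapes): three (n_mel, frames) mels of
# 80, 120 and 95 frames come back as a single (3, 120, 100) tensor:
#   mels = [torch.randn(100, n) for n in (80, 120, 95)]
#   padded_mel_batch(mels).shape  # torch.Size([3, 120, 100])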


def get_seedtts_testset_metainfo(metalst):
    with open(metalst) as f:
        lines = f.readlines()
    metainfo = []
    for line in lines:
        fields = line.strip().split("|")
        if len(fields) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
        elif len(fields) == 4:
            utt, prompt_text, prompt_wav, gt_text = fields
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo
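

# Expected .lst format: one "|"-separated utterance per line, with either 5
# fields (utt|prompt_text|prompt_wav|gt_text|gt_wav) or 4, in which case the
# ground-truth wav is resolved to <metalst_dir>/wavs/<utt>.wav. A hypothetical
# 4-field line:
#   utt_0001|这是提示文本。|prompt-wavs/utt_0001.wav|这是要合成的文本。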


accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# -------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1


def main():
    # ---------------------- infer setting ---------------------- #

    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=15000, type=int)
    parser.add_argument(
        "-m",
        "--mel_spec_type",
        default="bigvgan",
        type=str,
        choices=["bigvgan", "vocos"],
    )
    parser.add_argument(
        "-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
    )

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)

    args = parser.parse_args()

    seed = args.seed
    dataset_name = args.dataset
    exp_name = args.expname
    ckpt_step = args.ckptstep

    # NOTE: the checkpoint path is hard-coded and not derived from
    # --expname/--ckptstep; for a checkpoint in the original F5-TTS format,
    # use load_pretrained_checkpoint below instead of icefall's load_checkpoint:
    # ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
    ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"

    mel_spec_type = args.mel_spec_type
    tokenizer = args.tokenizer

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset

    # frame budget per batch; 1 keeps one utterance per batch
    # (recommended for DDP inference)
    infer_batch_size = 1
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    model_cls = DiT
    model_cfg = dict(
        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
    )
    metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

    # path to save generated wavs
    output_dir = (
        f"./"
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )
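    # For example (hypothetical arguments), "-n f5 -t zh" with the defaults
    # above yields:
    #   ./results/f5_15000/zh/seedNone_euler_nfe32_bigvgan_ss-1.0_cfg2.0_speed1.0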

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    vocoder = load_vocoder(device)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")

    # Model
    model = CFM(
        transformer=model_cls(
            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
        ),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # for the original F5-TTS checkpoint format, use:
    # model = load_pretrained_checkpoint(model, ckpt_path)
    _ = load_checkpoint(
        ckpt_path,
        model=model,
    )
    model = model.eval().to(device)

    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            (
                utts,
                ref_rms_list,
                ref_mels,
                ref_mel_lens,
                total_mel_lens,
                final_text_list,
            ) = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
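                # generated is a (batch, duration, n_mel) mel batch whose first
                # ref_mel_len frames are the conditioning prompt; only the
                # continuation [ref_mel_len, total_mel_len) is vocoded below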
                # Final result
                for i, gen in enumerate(generated):
                    # drop the prompt frames, keep the generated continuation
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec).cpu()
                    elif mel_spec_type == "bigvgan":
                        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

                    # undo the prompt RMS normalization applied in get_inference_prompt
                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(
                        f"{output_dir}/{utts[i]}.wav",
                        generated_wave,
                        target_sample_rate,
                    )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60:.2f} minutes.")


if __name__ == "__main__":
    main()