add infer

yuekaiz 2024-12-23 17:35:36 +08:00
parent 511f63b551
commit ec5cc5526e
7 changed files with 575 additions and 32 deletions

View File

@@ -0,0 +1,411 @@
import argparse
import math
import os
import random
import time
# import bigvan
# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
import torch
import torch.nn.functional as F
import torchaudio
from accelerate import Accelerator
# from importlib.resources import files
# import sys
# sys.path.append(f"/home/yuekaiz/BigVGAN/")
# from bigvgan import BigVGAN
from bigvganinference import BigVGANInference
# from f5_tts.eval.utils_eval import (
# get_inference_prompt,
# get_librispeech_test_clean_metainfo,
# get_seedtts_testset_metainfo,
# )
# from f5_tts.infer.utils_infer import load_vocoder
from model.cfm import CFM
from model.dit import DiT
from model.modules import MelSpec
from model.utils import convert_char_to_pinyin
from tqdm import tqdm
from train import get_tokenizer, load_pretrained_checkpoint
from icefall.checkpoint import load_checkpoint
def load_vocoder(device):
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
model = BigVGANInference.from_pretrained(
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
)
model = model.eval().to(device)
return model
def get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="vocos",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
):
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
metainfo, desc="Processing prompts..."
):
# Audio
ref_audio, ref_sr = torchaudio.load(prompt_wav)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
if ref_rms < target_rms:
ref_audio = ref_audio * target_rms / ref_rms
assert (
ref_audio.shape[-1] > 5000
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio)
# Text
if len(prompt_text[-1].encode("utf-8")) == 1:
prompt_text = prompt_text + " "
text = [prompt_text + gt_text]
if tokenizer == "pinyin":
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
else:
text_list = text
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(
ref_mel_len / ref_text_len * gen_text_len / speed
)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
final_text_list[bucket_i].extend(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
    # flush buckets that still hold a partial batch
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
    # shuffle so the last workers are not left with only the short, easy batches
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
return padded_ref_mels
def get_seedtts_testset_metainfo(metalst):
    with open(metalst) as f:
        lines = f.readlines()
    metainfo = []
for line in lines:
if len(line.strip().split("|")) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
elif len(line.strip().split("|")) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
return metainfo
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1
# rel_path = str(files("f5_tts").joinpath("../../"))
def main():
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=15000, type=int)
parser.add_argument(
"-m",
"--mel_spec_type",
default="bigvgan",
type=str,
choices=["bigvgan", "vocos"],
)
parser.add_argument(
"-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument("-t", "--testset", required=True)
args = parser.parse_args()
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
mel_spec_type = args.mel_spec_type
tokenizer = args.tokenizer
nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling
testset = args.testset
    infer_batch_size = 1  # max accumulated mel frames per batch; 1 => one utterance per batch (recommended for DDP inference)
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False
model_cls = DiT
model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
    # path to save generated wavs
output_dir = (
f"./"
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
mel_spec_type=mel_spec_type,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)
vocoder = load_vocoder(device)
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")
# Model
model = CFM(
transformer=model_cls(
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
# model = load_pretrained_checkpoint(model, ckpt_path)
_ = load_checkpoint(
ckpt_path,
model=model,
)
model = model.eval().to(device)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
# start batch inference
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
(
utts,
ref_rms_list,
ref_mels,
ref_mel_lens,
total_mel_lens,
final_text_list,
) = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec).cpu()
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(
f"{output_dir}/{utts[i]}.wav",
generated_wave,
target_sample_rate,
)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
if __name__ == "__main__":
main()
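For context, a minimal sketch of the meta list format consumed by get_seedtts_testset_metainfo and of the two entry calls above. The utterance IDs, transcripts, and relative paths are hypothetical; only the "|"-separated layout (5 fields, or 4 fields with the target wav resolved to wavs/<utt>.wav next to the list) follows the parser:

# meta_head.lst (hypothetical example lines)
# utt_0001|<prompt transcript>|prompts/utt_0001.wav|<target transcript>|wavs/utt_0001.wav
# utt_0002|<prompt transcript>|prompts/utt_0002.wav|<target transcript>
metainfo = get_seedtts_testset_metainfo("seedtts_testset/zh/meta_head.lst")
prompts_all = get_inference_prompt(metainfo, tokenizer="pinyin", mel_spec_type="bigvgan")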

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@@ -0,0 +1,104 @@
from typing import Callable, Dict, List, Sequence, Union
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone
class SpeechSynthesisDataset(torch.utils.data.Dataset):
"""
The PyTorch Dataset for the speech synthesis task.
Each item in this dataset is a dict of:
.. code-block::
{
'audio': (B x NumSamples) float tensor
'features': (B x NumFrames x NumFeatures) float tensor
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'text': List[str] of len B # when return_text=True
'tokens': List[List[str]] # when return_tokens=True
'speakers': List[str] of len B # when return_spk_ids=True
'cut': List of Cuts # when return_cuts=True
}
"""
def __init__(
self,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
feature_input_strategy: BatchIO = PrecomputedFeatures(),
feature_transforms: Union[Sequence[Callable], Callable] = None,
return_text: bool = True,
return_tokens: bool = False,
return_spk_ids: bool = False,
return_cuts: bool = False,
) -> None:
super().__init__()
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy
self.return_text = return_text
self.return_tokens = return_tokens
self.return_spk_ids = return_spk_ids
self.return_cuts = return_cuts
if feature_transforms is None:
feature_transforms = []
elif not isinstance(feature_transforms, Sequence):
feature_transforms = [feature_transforms]
assert all(
isinstance(transform, Callable) for transform in feature_transforms
), "Feature transforms must be Callable"
self.feature_transforms = feature_transforms
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
validate_for_tts(cuts)
for transform in self.cut_transforms:
cuts = transform(cuts)
# audio, audio_lens = collate_audio(cuts)
features, features_lens = self.feature_input_strategy(cuts)
for transform in self.feature_transforms:
features = transform(features)
batch = {
# "audio": audio,
"features": features,
# "audio_lens": audio_lens,
"features_lens": features_lens,
}
if self.return_text:
# use normalized text
# text = [cut.supervisions[0].normalized_text for cut in cuts]
text = [cut.supervisions[0].text for cut in cuts]
batch["text"] = text
if self.return_tokens:
# tokens = [cut.tokens for cut in cuts]
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
batch["tokens"] = tokens
if self.return_spk_ids:
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
if self.return_cuts:
batch["cut"] = [cut for cut in cuts]
return batch
def validate_for_tts(cuts: CutSet) -> None:
validate(cuts)
for cut in cuts:
assert (
len(cut.supervisions) == 1
), "Only the Cuts with single supervision are supported."

View File

@@ -47,10 +47,13 @@ from model.dit import DiT
from model.utils import convert_char_to_pinyin
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
# from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from utils import MetricsTracker
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -61,7 +64,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import AttributeDict, setup_logger, str2bool # MetricsTracker
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -340,7 +343,7 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 1,
"log_interval": 100,
"reset_interval": 200,
"valid_interval": 10000,
"env_info": get_env_info(),
@@ -411,12 +414,12 @@ def load_pretrained_checkpoint(
):
# model = model.to(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
if "ema_model_state_dict" in checkpoint:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
# patch for backward compatibility, 305e3ea
for key in [
@@ -553,7 +556,7 @@ def prepare_input(batch: dict, device: torch.device):
text_inputs = batch["text"]
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
print(text_inputs)
mel_spec = batch["features"]
mel_lengths = batch["features_lens"]
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
@@ -591,22 +594,13 @@ def compute_loss(
with torch.set_grad_enabled(is_training):
loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
assert loss.requires_grad == is_training
print(loss)
# from accelerate import Accelerator
# from accelerate.utils import DistributedDataParallelKwargs
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# accelerator = Accelerator(
# kwargs_handlers=[ddp_kwargs],
# )
# accelerator.backward(loss)
# loss.backward()
info = MetricsTracker()
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# info["samples"] = mel_lengths.size(0)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["samples"] = mel_lengths.size(0)
# info["loss"] = loss.detach().cpu().item() * info["samples"]
info["loss"] = loss.detach().cpu().item() * info["samples"]
return loss, info
@@ -633,7 +627,7 @@ def compute_validation_loss(
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
loss_value = tot_loss["loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
@@ -721,7 +715,7 @@ def train_one_epoch(
batch_size = len(batch["text"])
try:
with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -749,7 +743,7 @@
scaler.step(optimizer)
scaler.update()
# optimizer.zero_grad()
optimizer.zero_grad()
# loss.backward()
# optimizer.step()
@@ -856,7 +850,7 @@
# Calculate validation loss in Rank 0
model.eval()
logging.info("Computing validation loss")
with torch.cuda.amp.autocast(dtype=dtype):
with torch.amp.autocast("cuda", dtype=dtype):
valid_info = compute_validation_loss(
params=params,
model=model,
@@ -876,7 +870,7 @@
model.train()
loss_value = tot_loss["loss"] / tot_loss["frames"]
loss_value = tot_loss["loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
@@ -944,7 +938,6 @@ def run(rank, world_size, args):
model = get_model(params)
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
model = model.to(device)
with open(f"{params.exp_dir}/model.txt", "w") as f:
@@ -969,7 +962,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=False)
model_parameters = model.parameters()
@@ -1046,7 +1039,9 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
scaler = GradScaler(
"cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1141,7 +1136,7 @@ def scan_pessimistic_batches_for_oom(
batch = train_dl.dataset[cuts]
print(batch.keys())
try:
with torch.cuda.amp.autocast(dtype=dtype):
with torch.amp.autocast("cuda", dtype=dtype):
loss, loss_info = compute_loss(
params=params,
model=model,
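For reference, the autocast/GradScaler edits in this file switch from the torch.cuda.amp namespace to the newer torch.amp API, which takes the device type as its first positional argument. A minimal self-contained sketch, assuming PyTorch >= 2.3 (where torch.amp.GradScaler is available) and a CUDA device:

import torch
from torch.amp import GradScaler, autocast

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler("cuda", enabled=True)      # was: torch.cuda.amp.GradScaler(enabled=True)
with autocast("cuda", dtype=torch.float16):    # was: torch.cuda.amp.autocast(dtype=torch.float16)
    loss = model(torch.randn(4, 8, device="cuda")).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()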

View File

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/utils.py

View File

@@ -0,0 +1,3 @@
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
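Note that eval_infer_batch.py loads the BigVGAN vocoder from a local directory; per the comment in its load_vocoder helper, the checkpoint can be fetched beforehand (from the same working directory as the command above) with:

huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x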

View File

@@ -0,0 +1,28 @@
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
install_flag=false
if [ "$install_flag" = true ]; then
echo "Installing packages..."
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# lhotse tensorboard kaldialign
pip install -r requirements.txt
pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
apt-get update && apt-get -y install festival espeak-ng mbrola
else
echo "Skipping installation."
fi
world_size=8
#world_size=1
exp_dir=exp/f5
# pip install -r f5-tts/requirements.txt
python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
--base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
--num-epochs 10 --start-epoch 1 --start-batch 20000 \
--exp-dir ${exp_dir} --world-size ${world_size}