mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
add infer
This commit is contained in:
parent
511f63b551
commit
ec5cc5526e
411
egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py
Normal file
411
egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
# import bigvan
|
||||||
|
# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
# from importlib.resources import files
|
||||||
|
# import sys
|
||||||
|
# sys.path.append(f"/home/yuekaiz/BigVGAN/")
|
||||||
|
# from bigvgan import BigVGAN
|
||||||
|
from bigvganinference import BigVGANInference
|
||||||
|
|
||||||
|
# from f5_tts.eval.utils_eval import (
|
||||||
|
# get_inference_prompt,
|
||||||
|
# get_librispeech_test_clean_metainfo,
|
||||||
|
# get_seedtts_testset_metainfo,
|
||||||
|
# )
|
||||||
|
# from f5_tts.infer.utils_infer import load_vocoder
|
||||||
|
from model.cfm import CFM
|
||||||
|
from model.dit import DiT
|
||||||
|
from model.modules import MelSpec
|
||||||
|
from model.utils import convert_char_to_pinyin
|
||||||
|
from tqdm import tqdm
|
||||||
|
from train import get_tokenizer, load_pretrained_checkpoint
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocoder(device):
|
||||||
|
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
|
||||||
|
model = BigVGANInference.from_pretrained(
|
||||||
|
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
|
||||||
|
)
|
||||||
|
model = model.eval().to(device)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_prompt(
|
||||||
|
metainfo,
|
||||||
|
speed=1.0,
|
||||||
|
tokenizer="pinyin",
|
||||||
|
polyphone=True,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
n_fft=1024,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
hop_length=256,
|
||||||
|
mel_spec_type="vocos",
|
||||||
|
target_rms=0.1,
|
||||||
|
use_truth_duration=False,
|
||||||
|
infer_batch_size=1,
|
||||||
|
num_buckets=200,
|
||||||
|
min_secs=3,
|
||||||
|
max_secs=40,
|
||||||
|
):
|
||||||
|
prompts_all = []
|
||||||
|
|
||||||
|
min_tokens = min_secs * target_sample_rate // hop_length
|
||||||
|
max_tokens = max_secs * target_sample_rate // hop_length
|
||||||
|
|
||||||
|
batch_accum = [0] * num_buckets
|
||||||
|
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||||
|
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spectrogram = MelSpec(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
n_mel_channels=n_mel_channels,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
|
||||||
|
metainfo, desc="Processing prompts..."
|
||||||
|
):
|
||||||
|
# Audio
|
||||||
|
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
||||||
|
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
||||||
|
if ref_rms < target_rms:
|
||||||
|
ref_audio = ref_audio * target_rms / ref_rms
|
||||||
|
assert (
|
||||||
|
ref_audio.shape[-1] > 5000
|
||||||
|
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio)
|
||||||
|
|
||||||
|
# Text
|
||||||
|
if len(prompt_text[-1].encode("utf-8")) == 1:
|
||||||
|
prompt_text = prompt_text + " "
|
||||||
|
text = [prompt_text + gt_text]
|
||||||
|
if tokenizer == "pinyin":
|
||||||
|
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
||||||
|
else:
|
||||||
|
text_list = text
|
||||||
|
|
||||||
|
# Duration, mel frame length
|
||||||
|
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||||
|
if use_truth_duration:
|
||||||
|
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||||
|
if gt_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
||||||
|
gt_audio = resampler(gt_audio)
|
||||||
|
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
||||||
|
|
||||||
|
# # test vocoder resynthesis
|
||||||
|
# ref_audio = gt_audio
|
||||||
|
else:
|
||||||
|
ref_text_len = len(prompt_text.encode("utf-8"))
|
||||||
|
gen_text_len = len(gt_text.encode("utf-8"))
|
||||||
|
total_mel_len = ref_mel_len + int(
|
||||||
|
ref_mel_len / ref_text_len * gen_text_len / speed
|
||||||
|
)
|
||||||
|
|
||||||
|
# to mel spectrogram
|
||||||
|
ref_mel = mel_spectrogram(ref_audio)
|
||||||
|
ref_mel = ref_mel.squeeze(0)
|
||||||
|
|
||||||
|
# deal with batch
|
||||||
|
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||||
|
assert (
|
||||||
|
min_tokens <= total_mel_len <= max_tokens
|
||||||
|
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||||
|
bucket_i = math.floor(
|
||||||
|
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||||
|
)
|
||||||
|
|
||||||
|
utts[bucket_i].append(utt)
|
||||||
|
ref_rms_list[bucket_i].append(ref_rms)
|
||||||
|
ref_mels[bucket_i].append(ref_mel)
|
||||||
|
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||||
|
total_mel_lens[bucket_i].append(total_mel_len)
|
||||||
|
final_text_list[bucket_i].extend(text_list)
|
||||||
|
|
||||||
|
batch_accum[bucket_i] += total_mel_len
|
||||||
|
|
||||||
|
if batch_accum[bucket_i] >= infer_batch_size:
|
||||||
|
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch_accum[bucket_i] = 0
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
ref_mels[bucket_i],
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
) = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# add residual
|
||||||
|
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||||
|
if bucket_frames > 0:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# not only leave easy work for last workers
|
||||||
|
random.seed(666)
|
||||||
|
random.shuffle(prompts_all)
|
||||||
|
|
||||||
|
return prompts_all
|
||||||
|
|
||||||
|
|
||||||
|
def padded_mel_batch(ref_mels):
|
||||||
|
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||||
|
padded_ref_mels = []
|
||||||
|
for mel in ref_mels:
|
||||||
|
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
|
||||||
|
padded_ref_mels.append(padded_ref_mel)
|
||||||
|
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||||
|
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
||||||
|
return padded_ref_mels
|
||||||
|
|
||||||
|
|
||||||
|
def get_seedtts_testset_metainfo(metalst):
|
||||||
|
f = open(metalst)
|
||||||
|
lines = f.readlines()
|
||||||
|
f.close()
|
||||||
|
metainfo = []
|
||||||
|
for line in lines:
|
||||||
|
if len(line.strip().split("|")) == 5:
|
||||||
|
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
||||||
|
elif len(line.strip().split("|")) == 4:
|
||||||
|
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||||
|
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
||||||
|
if not os.path.isabs(prompt_wav):
|
||||||
|
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
||||||
|
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
|
||||||
|
return metainfo
|
||||||
|
|
||||||
|
|
||||||
|
accelerator = Accelerator()
|
||||||
|
device = f"cuda:{accelerator.process_index}"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------- Dataset Settings -------------------- #
|
||||||
|
|
||||||
|
target_sample_rate = 24000
|
||||||
|
n_mel_channels = 100
|
||||||
|
hop_length = 256
|
||||||
|
win_length = 1024
|
||||||
|
n_fft = 1024
|
||||||
|
target_rms = 0.1
|
||||||
|
|
||||||
|
# rel_path = str(files("f5_tts").joinpath("../../"))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# ---------------------- infer setting ---------------------- #
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="batch inference")
|
||||||
|
|
||||||
|
parser.add_argument("-s", "--seed", default=None, type=int)
|
||||||
|
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
|
||||||
|
parser.add_argument("-n", "--expname", required=True)
|
||||||
|
parser.add_argument("-c", "--ckptstep", default=15000, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"-m",
|
||||||
|
"--mel_spec_type",
|
||||||
|
default="bigvgan",
|
||||||
|
type=str,
|
||||||
|
choices=["bigvgan", "vocos"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
|
||||||
|
parser.add_argument("-o", "--odemethod", default="euler")
|
||||||
|
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||||
|
|
||||||
|
parser.add_argument("-t", "--testset", required=True)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
seed = args.seed
|
||||||
|
dataset_name = args.dataset
|
||||||
|
exp_name = args.expname
|
||||||
|
ckpt_step = args.ckptstep
|
||||||
|
|
||||||
|
ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
|
||||||
|
ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
|
||||||
|
|
||||||
|
mel_spec_type = args.mel_spec_type
|
||||||
|
tokenizer = args.tokenizer
|
||||||
|
|
||||||
|
nfe_step = args.nfestep
|
||||||
|
ode_method = args.odemethod
|
||||||
|
sway_sampling_coef = args.swaysampling
|
||||||
|
|
||||||
|
testset = args.testset
|
||||||
|
|
||||||
|
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
||||||
|
cfg_strength = 2.0
|
||||||
|
speed = 1.0
|
||||||
|
use_truth_duration = False
|
||||||
|
no_ref_audio = False
|
||||||
|
|
||||||
|
model_cls = DiT
|
||||||
|
model_cfg = dict(
|
||||||
|
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
||||||
|
)
|
||||||
|
metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
|
||||||
|
metainfo = get_seedtts_testset_metainfo(metalst)
|
||||||
|
|
||||||
|
# path to save genereted wavs
|
||||||
|
output_dir = (
|
||||||
|
f"./"
|
||||||
|
f"results/{exp_name}_{ckpt_step}/{testset}/"
|
||||||
|
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
|
||||||
|
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
|
||||||
|
f"_cfg{cfg_strength}_speed{speed}"
|
||||||
|
f"{'_gt-dur' if use_truth_duration else ''}"
|
||||||
|
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts_all = get_inference_prompt(
|
||||||
|
metainfo,
|
||||||
|
speed=speed,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
n_mel_channels=n_mel_channels,
|
||||||
|
hop_length=hop_length,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
|
target_rms=target_rms,
|
||||||
|
use_truth_duration=use_truth_duration,
|
||||||
|
infer_batch_size=infer_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
vocoder = load_vocoder(device)
|
||||||
|
|
||||||
|
# Tokenizer
|
||||||
|
vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")
|
||||||
|
|
||||||
|
# Model
|
||||||
|
model = CFM(
|
||||||
|
transformer=model_cls(
|
||||||
|
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
||||||
|
),
|
||||||
|
mel_spec_kwargs=dict(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
n_mel_channels=n_mel_channels,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
|
),
|
||||||
|
odeint_kwargs=dict(
|
||||||
|
method=ode_method,
|
||||||
|
),
|
||||||
|
vocab_char_map=vocab_char_map,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||||
|
# model = load_pretrained_checkpoint(model, ckpt_path)
|
||||||
|
_ = load_checkpoint(
|
||||||
|
ckpt_path,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
model = model.eval().to(device)
|
||||||
|
|
||||||
|
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
|
# start batch inference
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
with accelerator.split_between_processes(prompts_all) as prompts:
|
||||||
|
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
||||||
|
(
|
||||||
|
utts,
|
||||||
|
ref_rms_list,
|
||||||
|
ref_mels,
|
||||||
|
ref_mel_lens,
|
||||||
|
total_mel_lens,
|
||||||
|
final_text_list,
|
||||||
|
) = prompt
|
||||||
|
ref_mels = ref_mels.to(device)
|
||||||
|
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
||||||
|
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
with torch.inference_mode():
|
||||||
|
generated, _ = model.sample(
|
||||||
|
cond=ref_mels,
|
||||||
|
text=final_text_list,
|
||||||
|
duration=total_mel_lens,
|
||||||
|
lens=ref_mel_lens,
|
||||||
|
steps=nfe_step,
|
||||||
|
cfg_strength=cfg_strength,
|
||||||
|
sway_sampling_coef=sway_sampling_coef,
|
||||||
|
no_ref_audio=no_ref_audio,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
# Final result
|
||||||
|
for i, gen in enumerate(generated):
|
||||||
|
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||||
|
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||||
|
if mel_spec_type == "vocos":
|
||||||
|
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||||
|
elif mel_spec_type == "bigvgan":
|
||||||
|
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||||
|
|
||||||
|
if ref_rms_list[i] < target_rms:
|
||||||
|
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||||
|
torchaudio.save(
|
||||||
|
f"{output_dir}/{utts[i]}.wav",
|
||||||
|
generated_wave,
|
||||||
|
target_sample_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
timediff = time.time() - start
|
||||||
|
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/wenetspeech4tts/TTS/f5-tts/optim.py
Symbolic link
1
egs/wenetspeech4tts/TTS/f5-tts/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/optim.py
|
104
egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py
Normal file
104
egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
from typing import Callable, Dict, List, Sequence, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import validate
|
||||||
|
from lhotse.cut import CutSet
|
||||||
|
from lhotse.dataset.collation import collate_audio
|
||||||
|
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||||
|
from lhotse.utils import ifnone
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechSynthesisDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
The PyTorch Dataset for the speech synthesis task.
|
||||||
|
Each item in this dataset is a dict of:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
{
|
||||||
|
'audio': (B x NumSamples) float tensor
|
||||||
|
'features': (B x NumFrames x NumFeatures) float tensor
|
||||||
|
'audio_lens': (B, ) int tensor
|
||||||
|
'features_lens': (B, ) int tensor
|
||||||
|
'text': List[str] of len B # when return_text=True
|
||||||
|
'tokens': List[List[str]] # when return_tokens=True
|
||||||
|
'speakers': List[str] of len B # when return_spk_ids=True
|
||||||
|
'cut': List of Cuts # when return_cuts=True
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||||
|
feature_input_strategy: BatchIO = PrecomputedFeatures(),
|
||||||
|
feature_transforms: Union[Sequence[Callable], Callable] = None,
|
||||||
|
return_text: bool = True,
|
||||||
|
return_tokens: bool = False,
|
||||||
|
return_spk_ids: bool = False,
|
||||||
|
return_cuts: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cut_transforms = ifnone(cut_transforms, [])
|
||||||
|
self.feature_input_strategy = feature_input_strategy
|
||||||
|
|
||||||
|
self.return_text = return_text
|
||||||
|
self.return_tokens = return_tokens
|
||||||
|
self.return_spk_ids = return_spk_ids
|
||||||
|
self.return_cuts = return_cuts
|
||||||
|
|
||||||
|
if feature_transforms is None:
|
||||||
|
feature_transforms = []
|
||||||
|
elif not isinstance(feature_transforms, Sequence):
|
||||||
|
feature_transforms = [feature_transforms]
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
isinstance(transform, Callable) for transform in feature_transforms
|
||||||
|
), "Feature transforms must be Callable"
|
||||||
|
self.feature_transforms = feature_transforms
|
||||||
|
|
||||||
|
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
|
||||||
|
validate_for_tts(cuts)
|
||||||
|
|
||||||
|
for transform in self.cut_transforms:
|
||||||
|
cuts = transform(cuts)
|
||||||
|
|
||||||
|
# audio, audio_lens = collate_audio(cuts)
|
||||||
|
features, features_lens = self.feature_input_strategy(cuts)
|
||||||
|
|
||||||
|
for transform in self.feature_transforms:
|
||||||
|
features = transform(features)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
# "audio": audio,
|
||||||
|
"features": features,
|
||||||
|
# "audio_lens": audio_lens,
|
||||||
|
"features_lens": features_lens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.return_text:
|
||||||
|
# use normalized text
|
||||||
|
# text = [cut.supervisions[0].normalized_text for cut in cuts]
|
||||||
|
text = [cut.supervisions[0].text for cut in cuts]
|
||||||
|
batch["text"] = text
|
||||||
|
|
||||||
|
if self.return_tokens:
|
||||||
|
# tokens = [cut.tokens for cut in cuts]
|
||||||
|
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
|
||||||
|
batch["tokens"] = tokens
|
||||||
|
|
||||||
|
if self.return_spk_ids:
|
||||||
|
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
|
||||||
|
|
||||||
|
if self.return_cuts:
|
||||||
|
batch["cut"] = [cut for cut in cuts]
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def validate_for_tts(cuts: CutSet) -> None:
|
||||||
|
validate(cuts)
|
||||||
|
for cut in cuts:
|
||||||
|
assert (
|
||||||
|
len(cut.supervisions) == 1
|
||||||
|
), "Only the Cuts with single supervision are supported."
|
@ -47,10 +47,13 @@ from model.dit import DiT
|
|||||||
from model.utils import convert_char_to_pinyin
|
from model.utils import convert_char_to_pinyin
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.cuda.amp import GradScaler
|
|
||||||
|
# from torch.cuda.amp import GradScaler
|
||||||
|
from torch.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tts_datamodule import TtsDataModule
|
from tts_datamodule import TtsDataModule
|
||||||
|
from utils import MetricsTracker
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
@ -61,7 +64,7 @@ from icefall.checkpoint import (
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, setup_logger, str2bool # MetricsTracker
|
||||||
|
|
||||||
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
||||||
|
|
||||||
@ -340,7 +343,7 @@ def get_params() -> AttributeDict:
|
|||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 1,
|
"log_interval": 100,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 10000,
|
"valid_interval": 10000,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -411,12 +414,12 @@ def load_pretrained_checkpoint(
|
|||||||
):
|
):
|
||||||
# model = model.to(dtype)
|
# model = model.to(dtype)
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||||
|
if "ema_model_state_dict" in checkpoint:
|
||||||
checkpoint["model_state_dict"] = {
|
checkpoint["model_state_dict"] = {
|
||||||
k.replace("ema_model.", ""): v
|
k.replace("ema_model.", ""): v
|
||||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||||
if k not in ["initted", "step"]
|
if k not in ["initted", "step"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# patch for backward compatibility, 305e3ea
|
# patch for backward compatibility, 305e3ea
|
||||||
for key in [
|
for key in [
|
||||||
@ -553,7 +556,7 @@ def prepare_input(batch: dict, device: torch.device):
|
|||||||
text_inputs = batch["text"]
|
text_inputs = batch["text"]
|
||||||
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
|
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
|
||||||
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
|
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
|
||||||
print(text_inputs)
|
|
||||||
mel_spec = batch["features"]
|
mel_spec = batch["features"]
|
||||||
mel_lengths = batch["features_lens"]
|
mel_lengths = batch["features_lens"]
|
||||||
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
|
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
|
||||||
@ -591,22 +594,13 @@ def compute_loss(
|
|||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
|
loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
print(loss)
|
|
||||||
# from accelerate import Accelerator
|
|
||||||
# from accelerate.utils import DistributedDataParallelKwargs
|
|
||||||
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
|
||||||
# accelerator = Accelerator(
|
|
||||||
# kwargs_handlers=[ddp_kwargs],
|
|
||||||
# )
|
|
||||||
# accelerator.backward(loss)
|
|
||||||
# loss.backward()
|
|
||||||
|
|
||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
# with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
# warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
# info["samples"] = mel_lengths.size(0)
|
info["samples"] = mel_lengths.size(0)
|
||||||
|
|
||||||
# info["loss"] = loss.detach().cpu().item() * info["samples"]
|
info["loss"] = loss.detach().cpu().item() * info["samples"]
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -633,7 +627,7 @@ def compute_validation_loss(
|
|||||||
tot_loss = tot_loss + loss_info
|
tot_loss = tot_loss + loss_info
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
tot_loss.reduce(loss.device)
|
tot_loss.reduce(loss.device)
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["samples"]
|
||||||
if loss_value < params.best_valid_loss:
|
if loss_value < params.best_valid_loss:
|
||||||
params.best_valid_epoch = params.cur_epoch
|
params.best_valid_epoch = params.cur_epoch
|
||||||
params.best_valid_loss = loss_value
|
params.best_valid_loss = loss_value
|
||||||
@ -721,7 +715,7 @@ def train_one_epoch(
|
|||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["text"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
|
with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
@ -749,7 +743,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
# optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
# loss.backward()
|
# loss.backward()
|
||||||
# optimizer.step()
|
# optimizer.step()
|
||||||
|
|
||||||
@ -856,7 +850,7 @@ def train_one_epoch(
|
|||||||
# Calculate validation loss in Rank 0
|
# Calculate validation loss in Rank 0
|
||||||
model.eval()
|
model.eval()
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
with torch.cuda.amp.autocast(dtype=dtype):
|
with torch.amp.autocast("cuda", dtype=dtype):
|
||||||
valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
@ -876,7 +870,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["samples"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
if params.train_loss < params.best_train_loss:
|
if params.train_loss < params.best_train_loss:
|
||||||
params.best_train_epoch = params.cur_epoch
|
params.best_train_epoch = params.cur_epoch
|
||||||
@ -944,7 +938,6 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
|
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||||
|
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
with open(f"{params.exp_dir}/model.txt", "w") as f:
|
with open(f"{params.exp_dir}/model.txt", "w") as f:
|
||||||
@ -969,7 +962,7 @@ def run(rank, world_size, args):
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
model = DDP(model, device_ids=[rank], find_unused_parameters=False)
|
||||||
|
|
||||||
model_parameters = model.parameters()
|
model_parameters = model.parameters()
|
||||||
|
|
||||||
@ -1046,7 +1039,9 @@ def run(rank, world_size, args):
|
|||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
|
scaler = GradScaler(
|
||||||
|
"cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
|
||||||
|
)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
logging.info("Loading grad scaler state dict")
|
logging.info("Loading grad scaler state dict")
|
||||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
@ -1141,7 +1136,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
batch = train_dl.dataset[cuts]
|
batch = train_dl.dataset[cuts]
|
||||||
print(batch.keys())
|
print(batch.keys())
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(dtype=dtype):
|
with torch.amp.autocast("cuda", dtype=dtype):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
1
egs/wenetspeech4tts/TTS/f5-tts/utils.py
Symbolic link
1
egs/wenetspeech4tts/TTS/f5-tts/utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/matcha/utils.py
|
3
egs/wenetspeech4tts/TTS/infer_f5.sh
Normal file
3
egs/wenetspeech4tts/TTS/infer_f5.sh
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
|
||||||
|
|
||||||
|
accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
28
egs/wenetspeech4tts/TTS/train_f5.sh
Normal file
28
egs/wenetspeech4tts/TTS/train_f5.sh
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
|
||||||
|
|
||||||
|
install_flag=false
|
||||||
|
if [ "$install_flag" = true ]; then
|
||||||
|
echo "Installing packages..."
|
||||||
|
|
||||||
|
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||||
|
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||||
|
# lhotse tensorboard kaldialign
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
|
||||||
|
|
||||||
|
apt-get update && apt-get -y install festival espeak-ng mbrola
|
||||||
|
else
|
||||||
|
echo "Skipping installation."
|
||||||
|
fi
|
||||||
|
|
||||||
|
world_size=8
|
||||||
|
#world_size=1
|
||||||
|
|
||||||
|
exp_dir=exp/f5
|
||||||
|
|
||||||
|
# pip install -r f5-tts/requirements.txt
|
||||||
|
python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
|
||||||
|
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
|
||||||
|
--base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
|
||||||
|
--num-epochs 10 --start-epoch 1 --start-batch 20000 \
|
||||||
|
--exp-dir ${exp_dir} --world-size ${world_size}
|
Loading…
x
Reference in New Issue
Block a user