mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
add infer
This commit is contained in:
parent
511f63b551
commit
ec5cc5526e
411
egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py
Normal file
411
egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py
Normal file
@ -0,0 +1,411 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
# import bigvan
|
||||
# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from accelerate import Accelerator
|
||||
|
||||
# from importlib.resources import files
|
||||
# import sys
|
||||
# sys.path.append(f"/home/yuekaiz/BigVGAN/")
|
||||
# from bigvgan import BigVGAN
|
||||
from bigvganinference import BigVGANInference
|
||||
|
||||
# from f5_tts.eval.utils_eval import (
|
||||
# get_inference_prompt,
|
||||
# get_librispeech_test_clean_metainfo,
|
||||
# get_seedtts_testset_metainfo,
|
||||
# )
|
||||
# from f5_tts.infer.utils_infer import load_vocoder
|
||||
from model.cfm import CFM
|
||||
from model.dit import DiT
|
||||
from model.modules import MelSpec
|
||||
from model.utils import convert_char_to_pinyin
|
||||
from tqdm import tqdm
|
||||
from train import get_tokenizer, load_pretrained_checkpoint
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
||||
|
||||
def load_vocoder(device):
|
||||
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
|
||||
model = BigVGANInference.from_pretrained(
|
||||
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
|
||||
)
|
||||
model = model.eval().to(device)
|
||||
return model
|
||||
|
||||
|
||||
def get_inference_prompt(
|
||||
metainfo,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
polyphone=True,
|
||||
target_sample_rate=24000,
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="vocos",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
num_buckets=200,
|
||||
min_secs=3,
|
||||
max_secs=40,
|
||||
):
|
||||
prompts_all = []
|
||||
|
||||
min_tokens = min_secs * target_sample_rate // hop_length
|
||||
max_tokens = max_secs * target_sample_rate // hop_length
|
||||
|
||||
batch_accum = [0] * num_buckets
|
||||
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||
)
|
||||
|
||||
mel_spectrogram = MelSpec(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
)
|
||||
|
||||
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
|
||||
metainfo, desc="Processing prompts..."
|
||||
):
|
||||
# Audio
|
||||
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
||||
if ref_rms < target_rms:
|
||||
ref_audio = ref_audio * target_rms / ref_rms
|
||||
assert (
|
||||
ref_audio.shape[-1] > 5000
|
||||
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio)
|
||||
|
||||
# Text
|
||||
if len(prompt_text[-1].encode("utf-8")) == 1:
|
||||
prompt_text = prompt_text + " "
|
||||
text = [prompt_text + gt_text]
|
||||
if tokenizer == "pinyin":
|
||||
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
||||
else:
|
||||
text_list = text
|
||||
|
||||
# Duration, mel frame length
|
||||
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||
if use_truth_duration:
|
||||
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||
if gt_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
||||
gt_audio = resampler(gt_audio)
|
||||
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
||||
|
||||
# # test vocoder resynthesis
|
||||
# ref_audio = gt_audio
|
||||
else:
|
||||
ref_text_len = len(prompt_text.encode("utf-8"))
|
||||
gen_text_len = len(gt_text.encode("utf-8"))
|
||||
total_mel_len = ref_mel_len + int(
|
||||
ref_mel_len / ref_text_len * gen_text_len / speed
|
||||
)
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# deal with batch
|
||||
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||
assert (
|
||||
min_tokens <= total_mel_len <= max_tokens
|
||||
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||
bucket_i = math.floor(
|
||||
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||
)
|
||||
|
||||
utts[bucket_i].append(utt)
|
||||
ref_rms_list[bucket_i].append(ref_rms)
|
||||
ref_mels[bucket_i].append(ref_mel)
|
||||
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||
total_mel_lens[bucket_i].append(total_mel_len)
|
||||
final_text_list[bucket_i].extend(text_list)
|
||||
|
||||
batch_accum[bucket_i] += total_mel_len
|
||||
|
||||
if batch_accum[bucket_i] >= infer_batch_size:
|
||||
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
batch_accum[bucket_i] = 0
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
ref_mels[bucket_i],
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
# add residual
|
||||
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||
if bucket_frames > 0:
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
# not only leave easy work for last workers
|
||||
random.seed(666)
|
||||
random.shuffle(prompts_all)
|
||||
|
||||
return prompts_all
|
||||
|
||||
|
||||
def padded_mel_batch(ref_mels):
|
||||
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||
padded_ref_mels = []
|
||||
for mel in ref_mels:
|
||||
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
|
||||
padded_ref_mels.append(padded_ref_mel)
|
||||
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
||||
return padded_ref_mels
|
||||
|
||||
|
||||
def get_seedtts_testset_metainfo(metalst):
|
||||
f = open(metalst)
|
||||
lines = f.readlines()
|
||||
f.close()
|
||||
metainfo = []
|
||||
for line in lines:
|
||||
if len(line.strip().split("|")) == 5:
|
||||
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
||||
elif len(line.strip().split("|")) == 4:
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
||||
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
|
||||
return metainfo
|
||||
|
||||
|
||||
accelerator = Accelerator()
|
||||
device = f"cuda:{accelerator.process_index}"
|
||||
|
||||
|
||||
# --------------------- Dataset Settings -------------------- #
|
||||
|
||||
target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
n_fft = 1024
|
||||
target_rms = 0.1
|
||||
|
||||
# rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
def main():
|
||||
# ---------------------- infer setting ---------------------- #
|
||||
|
||||
parser = argparse.ArgumentParser(description="batch inference")
|
||||
|
||||
parser.add_argument("-s", "--seed", default=None, type=int)
|
||||
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
|
||||
parser.add_argument("-n", "--expname", required=True)
|
||||
parser.add_argument("-c", "--ckptstep", default=15000, type=int)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mel_spec_type",
|
||||
default="bigvgan",
|
||||
type=str,
|
||||
choices=["bigvgan", "vocos"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
|
||||
)
|
||||
|
||||
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
|
||||
parser.add_argument("-o", "--odemethod", default="euler")
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
|
||||
parser.add_argument("-t", "--testset", required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
seed = args.seed
|
||||
dataset_name = args.dataset
|
||||
exp_name = args.expname
|
||||
ckpt_step = args.ckptstep
|
||||
|
||||
ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
|
||||
ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
|
||||
|
||||
mel_spec_type = args.mel_spec_type
|
||||
tokenizer = args.tokenizer
|
||||
|
||||
nfe_step = args.nfestep
|
||||
ode_method = args.odemethod
|
||||
sway_sampling_coef = args.swaysampling
|
||||
|
||||
testset = args.testset
|
||||
|
||||
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
||||
cfg_strength = 2.0
|
||||
speed = 1.0
|
||||
use_truth_duration = False
|
||||
no_ref_audio = False
|
||||
|
||||
model_cls = DiT
|
||||
model_cfg = dict(
|
||||
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
||||
)
|
||||
metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
|
||||
metainfo = get_seedtts_testset_metainfo(metalst)
|
||||
|
||||
# path to save genereted wavs
|
||||
output_dir = (
|
||||
f"./"
|
||||
f"results/{exp_name}_{ckpt_step}/{testset}/"
|
||||
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
|
||||
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
|
||||
f"_cfg{cfg_strength}_speed{speed}"
|
||||
f"{'_gt-dur' if use_truth_duration else ''}"
|
||||
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
||||
)
|
||||
|
||||
prompts_all = get_inference_prompt(
|
||||
metainfo,
|
||||
speed=speed,
|
||||
tokenizer=tokenizer,
|
||||
target_sample_rate=target_sample_rate,
|
||||
n_mel_channels=n_mel_channels,
|
||||
hop_length=hop_length,
|
||||
mel_spec_type=mel_spec_type,
|
||||
target_rms=target_rms,
|
||||
use_truth_duration=use_truth_duration,
|
||||
infer_batch_size=infer_batch_size,
|
||||
)
|
||||
|
||||
vocoder = load_vocoder(device)
|
||||
|
||||
# Tokenizer
|
||||
vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")
|
||||
|
||||
# Model
|
||||
model = CFM(
|
||||
transformer=model_cls(
|
||||
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
||||
),
|
||||
mel_spec_kwargs=dict(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
),
|
||||
odeint_kwargs=dict(
|
||||
method=ode_method,
|
||||
),
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
# model = load_pretrained_checkpoint(model, ckpt_path)
|
||||
_ = load_checkpoint(
|
||||
ckpt_path,
|
||||
model=model,
|
||||
)
|
||||
model = model.eval().to(device)
|
||||
|
||||
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# start batch inference
|
||||
accelerator.wait_for_everyone()
|
||||
start = time.time()
|
||||
|
||||
with accelerator.split_between_processes(prompts_all) as prompts:
|
||||
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
||||
(
|
||||
utts,
|
||||
ref_rms_list,
|
||||
ref_mels,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
final_text_list,
|
||||
) = prompt
|
||||
ref_mels = ref_mels.to(device)
|
||||
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
||||
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||
|
||||
# Inference
|
||||
with torch.inference_mode():
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=final_text_list,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
no_ref_audio=no_ref_audio,
|
||||
seed=seed,
|
||||
)
|
||||
# Final result
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||
elif mel_spec_type == "bigvgan":
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
|
||||
if ref_rms_list[i] < target_rms:
|
||||
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||
torchaudio.save(
|
||||
f"{output_dir}/{utts[i]}.wav",
|
||||
generated_wave,
|
||||
target_sample_rate,
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
timediff = time.time() - start
|
||||
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/wenetspeech4tts/TTS/f5-tts/optim.py
Symbolic link
1
egs/wenetspeech4tts/TTS/f5-tts/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/optim.py
|
104
egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py
Normal file
104
egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py
Normal file
@ -0,0 +1,104 @@
|
||||
from typing import Callable, Dict, List, Sequence, Union
|
||||
|
||||
import torch
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset.collation import collate_audio
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import ifnone
|
||||
|
||||
|
||||
class SpeechSynthesisDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
The PyTorch Dataset for the speech synthesis task.
|
||||
Each item in this dataset is a dict of:
|
||||
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
'audio': (B x NumSamples) float tensor
|
||||
'features': (B x NumFrames x NumFeatures) float tensor
|
||||
'audio_lens': (B, ) int tensor
|
||||
'features_lens': (B, ) int tensor
|
||||
'text': List[str] of len B # when return_text=True
|
||||
'tokens': List[List[str]] # when return_tokens=True
|
||||
'speakers': List[str] of len B # when return_spk_ids=True
|
||||
'cut': List of Cuts # when return_cuts=True
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||
feature_input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
feature_transforms: Union[Sequence[Callable], Callable] = None,
|
||||
return_text: bool = True,
|
||||
return_tokens: bool = False,
|
||||
return_spk_ids: bool = False,
|
||||
return_cuts: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.cut_transforms = ifnone(cut_transforms, [])
|
||||
self.feature_input_strategy = feature_input_strategy
|
||||
|
||||
self.return_text = return_text
|
||||
self.return_tokens = return_tokens
|
||||
self.return_spk_ids = return_spk_ids
|
||||
self.return_cuts = return_cuts
|
||||
|
||||
if feature_transforms is None:
|
||||
feature_transforms = []
|
||||
elif not isinstance(feature_transforms, Sequence):
|
||||
feature_transforms = [feature_transforms]
|
||||
|
||||
assert all(
|
||||
isinstance(transform, Callable) for transform in feature_transforms
|
||||
), "Feature transforms must be Callable"
|
||||
self.feature_transforms = feature_transforms
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
|
||||
validate_for_tts(cuts)
|
||||
|
||||
for transform in self.cut_transforms:
|
||||
cuts = transform(cuts)
|
||||
|
||||
# audio, audio_lens = collate_audio(cuts)
|
||||
features, features_lens = self.feature_input_strategy(cuts)
|
||||
|
||||
for transform in self.feature_transforms:
|
||||
features = transform(features)
|
||||
|
||||
batch = {
|
||||
# "audio": audio,
|
||||
"features": features,
|
||||
# "audio_lens": audio_lens,
|
||||
"features_lens": features_lens,
|
||||
}
|
||||
|
||||
if self.return_text:
|
||||
# use normalized text
|
||||
# text = [cut.supervisions[0].normalized_text for cut in cuts]
|
||||
text = [cut.supervisions[0].text for cut in cuts]
|
||||
batch["text"] = text
|
||||
|
||||
if self.return_tokens:
|
||||
# tokens = [cut.tokens for cut in cuts]
|
||||
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
|
||||
batch["tokens"] = tokens
|
||||
|
||||
if self.return_spk_ids:
|
||||
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
|
||||
|
||||
if self.return_cuts:
|
||||
batch["cut"] = [cut for cut in cuts]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def validate_for_tts(cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
for cut in cuts:
|
||||
assert (
|
||||
len(cut.supervisions) == 1
|
||||
), "Only the Cuts with single supervision are supported."
|
@ -47,10 +47,13 @@ from model.dit import DiT
|
||||
from model.utils import convert_char_to_pinyin
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
# from torch.cuda.amp import GradScaler
|
||||
from torch.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tts_datamodule import TtsDataModule
|
||||
from utils import MetricsTracker
|
||||
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
@ -61,7 +64,7 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool # MetricsTracker
|
||||
|
||||
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
||||
|
||||
@ -340,7 +343,7 @@ def get_params() -> AttributeDict:
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 1,
|
||||
"log_interval": 100,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 10000,
|
||||
"env_info": get_env_info(),
|
||||
@ -411,12 +414,12 @@ def load_pretrained_checkpoint(
|
||||
):
|
||||
# model = model.to(dtype)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
if "ema_model_state_dict" in checkpoint:
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
|
||||
# patch for backward compatibility, 305e3ea
|
||||
for key in [
|
||||
@ -553,7 +556,7 @@ def prepare_input(batch: dict, device: torch.device):
|
||||
text_inputs = batch["text"]
|
||||
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
|
||||
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
|
||||
print(text_inputs)
|
||||
|
||||
mel_spec = batch["features"]
|
||||
mel_lengths = batch["features_lens"]
|
||||
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
|
||||
@ -591,22 +594,13 @@ def compute_loss(
|
||||
with torch.set_grad_enabled(is_training):
|
||||
loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
|
||||
assert loss.requires_grad == is_training
|
||||
print(loss)
|
||||
# from accelerate import Accelerator
|
||||
# from accelerate.utils import DistributedDataParallelKwargs
|
||||
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# accelerator = Accelerator(
|
||||
# kwargs_handlers=[ddp_kwargs],
|
||||
# )
|
||||
# accelerator.backward(loss)
|
||||
# loss.backward()
|
||||
|
||||
info = MetricsTracker()
|
||||
# with warnings.catch_warnings():
|
||||
# warnings.simplefilter("ignore")
|
||||
# info["samples"] = mel_lengths.size(0)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["samples"] = mel_lengths.size(0)
|
||||
|
||||
# info["loss"] = loss.detach().cpu().item() * info["samples"]
|
||||
info["loss"] = loss.detach().cpu().item() * info["samples"]
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -633,7 +627,7 @@ def compute_validation_loss(
|
||||
tot_loss = tot_loss + loss_info
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
loss_value = tot_loss["loss"] / tot_loss["samples"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
@ -721,7 +715,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
|
||||
with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -749,7 +743,7 @@ def train_one_epoch(
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
# optimizer.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
|
||||
@ -856,7 +850,7 @@ def train_one_epoch(
|
||||
# Calculate validation loss in Rank 0
|
||||
model.eval()
|
||||
logging.info("Computing validation loss")
|
||||
with torch.cuda.amp.autocast(dtype=dtype):
|
||||
with torch.amp.autocast("cuda", dtype=dtype):
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -876,7 +870,7 @@ def train_one_epoch(
|
||||
|
||||
model.train()
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
loss_value = tot_loss["loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
@ -944,7 +938,6 @@ def run(rank, world_size, args):
|
||||
|
||||
model = get_model(params)
|
||||
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
|
||||
model = model.to(device)
|
||||
|
||||
with open(f"{params.exp_dir}/model.txt", "w") as f:
|
||||
@ -969,7 +962,7 @@ def run(rank, world_size, args):
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=False)
|
||||
|
||||
model_parameters = model.parameters()
|
||||
|
||||
@ -1046,7 +1039,9 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
|
||||
scaler = GradScaler(
|
||||
"cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
|
||||
)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1141,7 +1136,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
batch = train_dl.dataset[cuts]
|
||||
print(batch.keys())
|
||||
try:
|
||||
with torch.cuda.amp.autocast(dtype=dtype):
|
||||
with torch.amp.autocast("cuda", dtype=dtype):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
1
egs/wenetspeech4tts/TTS/f5-tts/utils.py
Symbolic link
1
egs/wenetspeech4tts/TTS/f5-tts/utils.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../ljspeech/TTS/matcha/utils.py
|
3
egs/wenetspeech4tts/TTS/infer_f5.sh
Normal file
3
egs/wenetspeech4tts/TTS/infer_f5.sh
Normal file
@ -0,0 +1,3 @@
|
||||
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
|
||||
|
||||
accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
28
egs/wenetspeech4tts/TTS/train_f5.sh
Normal file
28
egs/wenetspeech4tts/TTS/train_f5.sh
Normal file
@ -0,0 +1,28 @@
|
||||
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
|
||||
|
||||
install_flag=false
|
||||
if [ "$install_flag" = true ]; then
|
||||
echo "Installing packages..."
|
||||
|
||||
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
# lhotse tensorboard kaldialign
|
||||
pip install -r requirements.txt
|
||||
pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
|
||||
|
||||
apt-get update && apt-get -y install festival espeak-ng mbrola
|
||||
else
|
||||
echo "Skipping installation."
|
||||
fi
|
||||
|
||||
world_size=8
|
||||
#world_size=1
|
||||
|
||||
exp_dir=exp/f5
|
||||
|
||||
# pip install -r f5-tts/requirements.txt
|
||||
python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
|
||||
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
|
||||
--base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
|
||||
--num-epochs 10 --start-epoch 1 --start-batch 20000 \
|
||||
--exp-dir ${exp_dir} --world-size ${world_size}
|
Loading…
x
Reference in New Issue
Block a user