diff --git a/egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py b/egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py new file mode 100644 index 000000000..d70df9626 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py @@ -0,0 +1,411 @@ +import argparse +import math +import os +import random +import time + +# import bigvgan +# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/") +import torch +import torch.nn.functional as F +import torchaudio +from accelerate import Accelerator + +# from importlib.resources import files +# import sys +# sys.path.append(f"/home/yuekaiz/BigVGAN/") +# from bigvgan import BigVGAN +from bigvganinference import BigVGANInference + +# from f5_tts.eval.utils_eval import ( +# get_inference_prompt, +# get_librispeech_test_clean_metainfo, +# get_seedtts_testset_metainfo, +# ) +# from f5_tts.infer.utils_infer import load_vocoder +from model.cfm import CFM +from model.dit import DiT +from model.modules import MelSpec +from model.utils import convert_char_to_pinyin +from tqdm import tqdm +from train import get_tokenizer, load_pretrained_checkpoint + +from icefall.checkpoint import load_checkpoint + + +def load_vocoder(device): + # huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x + model = BigVGANInference.from_pretrained( + "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False + ) + model = model.eval().to(device) + return model + + +def get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="vocos", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio, ref_sr = torchaudio.load(prompt_wav) + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) + if ref_rms < target_rms: + ref_audio = ref_audio * target_rms / ref_rms + assert ( + ref_audio.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
+ if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio) + + # Text + if len(prompt_text[-1].encode("utf-8")) == 1: + prompt_text = prompt_text + " " + text = [prompt_text + gt_text] + if tokenizer == "pinyin": + text_list = convert_char_to_pinyin(text, polyphone=polyphone) + else: + text_list = text + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + final_text_list[bucket_i].extend(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def get_seedtts_testset_metainfo(metalst): + f = open(metalst) + lines = f.readlines() + f.close() + metainfo = [] + for line in lines: + if len(line.strip().split("|")) == 5: + utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") + elif len(line.strip().split("|")) == 4: + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + 
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) + metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) + return metainfo + + +accelerator = Accelerator() +device = f"cuda:{accelerator.process_index}" + + +# --------------------- Dataset Settings -------------------- # + +target_sample_rate = 24000 +n_mel_channels = 100 +hop_length = 256 +win_length = 1024 +n_fft = 1024 +target_rms = 0.1 + +# rel_path = str(files("f5_tts").joinpath("../../")) + + +def main(): + # ---------------------- infer setting ---------------------- # + + parser = argparse.ArgumentParser(description="batch inference") + + parser.add_argument("-s", "--seed", default=None, type=int) + parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") + parser.add_argument("-n", "--expname", required=True) + parser.add_argument("-c", "--ckptstep", default=15000, type=int) + parser.add_argument( + "-m", + "--mel_spec_type", + default="bigvgan", + type=str, + choices=["bigvgan", "vocos"], + ) + parser.add_argument( + "-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"] + ) + + parser.add_argument("-nfe", "--nfestep", default=32, type=int) + parser.add_argument("-o", "--odemethod", default="euler") + parser.add_argument("-ss", "--swaysampling", default=-1, type=float) + + parser.add_argument("-t", "--testset", required=True) + + args = parser.parse_args() + + seed = args.seed + dataset_name = args.dataset + exp_name = args.expname + ckpt_step = args.ckptstep + + ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt" + ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt" + + mel_spec_type = args.mel_spec_type + tokenizer = args.tokenizer + + nfe_step = args.nfestep + ode_method = args.odemethod + sway_sampling_coef = args.swaysampling + + testset = args.testset + + infer_batch_size = 1 # max frames. 
1 for ddp single inference (recommended) + cfg_strength = 2.0 + speed = 1.0 + use_truth_duration = False + no_ref_audio = False + + model_cls = DiT + model_cfg = dict( + dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4 + ) + metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst" + metainfo = get_seedtts_testset_metainfo(metalst) + + # path to save generated wavs + output_dir = ( + f"./" + f"results/{exp_name}_{ckpt_step}/{testset}/" + f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}" + f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" + f"_cfg{cfg_strength}_speed{speed}" + f"{'_gt-dur' if use_truth_duration else ''}" + f"{'_no-ref-audio' if no_ref_audio else ''}" + ) + + prompts_all = get_inference_prompt( + metainfo, + speed=speed, + tokenizer=tokenizer, + target_sample_rate=target_sample_rate, + n_mel_channels=n_mel_channels, + hop_length=hop_length, + mel_spec_type=mel_spec_type, + target_rms=target_rms, + use_truth_duration=use_truth_duration, + infer_batch_size=infer_batch_size, + ) + + vocoder = load_vocoder(device) + + # Tokenizer + vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt") + + # Model + model = CFM( + transformer=model_cls( + **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), + mel_spec_kwargs=dict( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ), + odeint_kwargs=dict( + method=ode_method, + ), + vocab_char_map=vocab_char_map, + ).to(device) + + dtype = torch.float32 if mel_spec_type == "bigvgan" else None + # model = load_pretrained_checkpoint(model, ckpt_path) + _ = load_checkpoint( + ckpt_path, + model=model, + ) + model = model.eval().to(device) + + if not os.path.exists(output_dir) and accelerator.is_main_process: + os.makedirs(output_dir) + + # start batch inference + accelerator.wait_for_everyone() + start = time.time() + + with accelerator.split_between_processes(prompts_all) as prompts: + for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): + ( + utts, + ref_rms_list, + ref_mels, + ref_mel_lens, + total_mel_lens, + final_text_list, + ) = prompt + ref_mels = ref_mels.to(device) + ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + + # Inference + with torch.inference_mode(): + generated, _ = model.sample( + cond=ref_mels, + text=final_text_list, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + no_ref_audio=no_ref_audio, + seed=seed, + ) + # Final result + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + if mel_spec_type == "vocos": + generated_wave = vocoder.decode(gen_mel_spec).cpu() + elif mel_spec_type == "bigvgan": + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + + if ref_rms_list[i] < target_rms: + generated_wave = generated_wave * ref_rms_list[i] / target_rms + torchaudio.save( + f"{output_dir}/{utts[i]}.wav", + generated_wave, + target_sample_rate, + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + timediff = time.time() - start + print(f"Done batch inference in {timediff / 60 :.2f} minutes.") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/optim.py
b/egs/wenetspeech4tts/TTS/f5-tts/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py new file mode 100644 index 000000000..57f677fcb --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py @@ -0,0 +1,104 @@ +from typing import Callable, Dict, List, Sequence, Union + +import torch +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.collation import collate_audio +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import ifnone + + +class SpeechSynthesisDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech synthesis task. + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'audio': (B x NumSamples) float tensor + 'features': (B x NumFrames x NumFeatures) float tensor + 'audio_lens': (B, ) int tensor + 'features_lens': (B, ) int tensor + 'text': List[str] of len B # when return_text=True + 'tokens': List[List[str]] # when return_tokens=True + 'speakers': List[str] of len B # when return_spk_ids=True + 'cut': List of Cuts # when return_cuts=True + } + """ + + def __init__( + self, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + feature_input_strategy: BatchIO = PrecomputedFeatures(), + feature_transforms: Union[Sequence[Callable], Callable] = None, + return_text: bool = True, + return_tokens: bool = False, + return_spk_ids: bool = False, + return_cuts: bool = False, + ) -> None: + super().__init__() + + self.cut_transforms = ifnone(cut_transforms, []) + self.feature_input_strategy = feature_input_strategy + + self.return_text = return_text + self.return_tokens = return_tokens + self.return_spk_ids = return_spk_ids + self.return_cuts = return_cuts + + if feature_transforms is None: + feature_transforms = [] + elif not isinstance(feature_transforms, Sequence): + feature_transforms = [feature_transforms] + + assert all( + isinstance(transform, Callable) for transform in feature_transforms + ), "Feature transforms must be Callable" + self.feature_transforms = feature_transforms + + def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: + validate_for_tts(cuts) + + for transform in self.cut_transforms: + cuts = transform(cuts) + + # audio, audio_lens = collate_audio(cuts) + features, features_lens = self.feature_input_strategy(cuts) + + for transform in self.feature_transforms: + features = transform(features) + + batch = { + # "audio": audio, + "features": features, + # "audio_lens": audio_lens, + "features_lens": features_lens, + } + + if self.return_text: + # use normalized text + # text = [cut.supervisions[0].normalized_text for cut in cuts] + text = [cut.supervisions[0].text for cut in cuts] + batch["text"] = text + + if self.return_tokens: + # tokens = [cut.tokens for cut in cuts] + tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] + batch["tokens"] = tokens + + if self.return_spk_ids: + batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] + + if self.return_cuts: + batch["cut"] = [cut for cut in cuts] + + return batch + + +def validate_for_tts(cuts: CutSet) -> None: + validate(cuts) + for cut in cuts: + assert ( + len(cut.supervisions) == 1 + ), "Only the Cuts with single supervision are supported." 
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 880e3748c..3009235c4 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -47,10 +47,13 @@ from model.dit import DiT from model.utils import convert_char_to_pinyin from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler + +# from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tts_datamodule import TtsDataModule +from utils import MetricsTracker from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl @@ -61,7 +64,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import AttributeDict, setup_logger, str2bool # MetricsTracker LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -340,7 +343,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 1, + "log_interval": 100, "reset_interval": 200, "valid_interval": 10000, "env_info": get_env_info(), @@ -411,12 +414,12 @@ def load_pretrained_checkpoint( ): # model = model.to(dtype) checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) - - checkpoint["model_state_dict"] = { - k.replace("ema_model.", ""): v - for k, v in checkpoint["ema_model_state_dict"].items() - if k not in ["initted", "step"] - } + if "ema_model_state_dict" in checkpoint: + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } # patch for backward compatibility, 305e3ea for key in [ @@ -553,7 +556,7 @@ def prepare_input(batch: dict, device: torch.device): text_inputs = batch["text"] # texts.extend(convert_char_to_pinyin([text], polyphone=true)) text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) - print(text_inputs) + mel_spec = batch["features"] mel_lengths = batch["features_lens"] return text_inputs, mel_spec.to(device), mel_lengths.to(device) @@ -591,22 +594,13 @@ def compute_loss( with torch.set_grad_enabled(is_training): loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths) assert loss.requires_grad == is_training - print(loss) - # from accelerate import Accelerator - # from accelerate.utils import DistributedDataParallelKwargs - # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - # accelerator = Accelerator( - # kwargs_handlers=[ddp_kwargs], - # ) - # accelerator.backward(loss) - # loss.backward() info = MetricsTracker() - # with warnings.catch_warnings(): - # warnings.simplefilter("ignore") - # info["samples"] = mel_lengths.size(0) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["samples"] = mel_lengths.size(0) - # info["loss"] = loss.detach().cpu().item() * info["samples"] + info["loss"] = loss.detach().cpu().item() * info["samples"] return loss, info @@ -633,7 +627,7 @@ def compute_validation_loss( tot_loss = tot_loss + loss_info if world_size > 1: tot_loss.reduce(loss.device) - loss_value = tot_loss["loss"] / tot_loss["frames"] + loss_value = tot_loss["loss"] / 
tot_loss["samples"] if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value @@ -721,7 +715,7 @@ def train_one_epoch( batch_size = len(batch["text"]) try: - with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): + with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled): loss, loss_info = compute_loss( params=params, model=model, @@ -749,7 +743,7 @@ def train_one_epoch( scaler.step(optimizer) scaler.update() - # optimizer.zero_grad() + optimizer.zero_grad() # loss.backward() # optimizer.step() @@ -856,7 +850,7 @@ def train_one_epoch( # Calculate validation loss in Rank 0 model.eval() logging.info("Computing validation loss") - with torch.cuda.amp.autocast(dtype=dtype): + with torch.amp.autocast("cuda", dtype=dtype): valid_info = compute_validation_loss( params=params, model=model, @@ -876,7 +870,7 @@ def train_one_epoch( model.train() - loss_value = tot_loss["loss"] / tot_loss["frames"] + loss_value = tot_loss["loss"] / tot_loss["samples"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -944,7 +938,6 @@ def run(rank, world_size, args): model = get_model(params) # model = load_pretrained_checkpoint(model, params.pretrained_model_path) - model = model.to(device) with open(f"{params.exp_dir}/model.txt", "w") as f: @@ -969,7 +962,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=False) model_parameters = model.parameters() @@ -1046,7 +1039,9 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) + scaler = GradScaler( + "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 + ) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1141,7 +1136,7 @@ def scan_pessimistic_batches_for_oom( batch = train_dl.dataset[cuts] print(batch.keys()) try: - with torch.cuda.amp.autocast(dtype=dtype): + with torch.amp.autocast("cuda", dtype=dtype): loss, loss_info = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech4tts/TTS/f5-tts/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/utils.py new file mode 120000 index 000000000..ceaaea196 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/infer_f5.sh b/egs/wenetspeech4tts/TTS/infer_f5.sh new file mode 100644 index 000000000..a2decbd78 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/infer_f5.sh @@ -0,0 +1,3 @@ +export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha + +accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 diff --git a/egs/wenetspeech4tts/TTS/train_f5.sh b/egs/wenetspeech4tts/TTS/train_f5.sh new file mode 100644 index 000000000..f29563531 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/train_f5.sh @@ -0,0 +1,28 @@ +export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha + +install_flag=false +if [ "$install_flag" = true ]; then + echo "Installing packages..." 
+ + pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html + # pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html + # lhotse tensorboard kaldialign + pip install -r requirements.txt + pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py + + apt-get update && apt-get -y install festival espeak-ng mbrola +else + echo "Skipping installation." +fi + +world_size=8 +#world_size=1 + +exp_dir=exp/f5 + +# pip install -r f5-tts/requirements.txt +python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \ + --base-lr 1e-4 --warmup-steps 5000 --average-period 200 \ + --num-epochs 10 --start-epoch 1 --start-batch 20000 \ + --exp-dir ${exp_dir} --world-size ${world_size}