From 511f63b551ef0edb771850e8e5ad2cd34d44418c Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Sun, 22 Dec 2024 18:48:26 +0800 Subject: [PATCH] using official repo --- egs/wenetspeech4tts/TTS/f5-tts/train.py | 77 ++++++------ .../TTS/f5-tts/tts_datamodule.py | 118 ++++++++++-------- 2 files changed, 100 insertions(+), 95 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 354b421b6..880e3748c 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -44,7 +44,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model.cfm import CFM from model.dit import DiT -from model.utils import MelSpec +from model.utils import convert_char_to_pinyin from optim import Eden, ScaledAdam from torch import Tensor from torch.cuda.amp import GradScaler @@ -151,7 +151,7 @@ def get_parser(): parser.add_argument( "--exp-dir", - type=str, + type=Path, default="exp/valle_dev", help="""The experiment dir. It specifies the directory where all training related @@ -162,7 +162,7 @@ def get_parser(): parser.add_argument( "--tokens", type=str, - default="ft-tts/vocab.txt", + default="f5-tts/vocab.txt", help="Path to the unique text tokens file", ) @@ -409,7 +409,7 @@ def get_model(params): def load_pretrained_checkpoint( model, ckpt_path, device: str = "cpu", dtype=torch.float32 ): - model = model.to(dtype) + # model = model.to(dtype) checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) checkpoint["model_state_dict"] = { @@ -548,14 +548,15 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) -def prepare_input(batch: dict, tokenizer, device: torch.device): +def prepare_input(batch: dict, device: torch.device): """Parse batch data""" - print(batch.keys()) - print(batch) text_inputs = batch["text"] - mel_spec = batch["mel"].permute(0, 2, 1) - mel_lengths = batch["mel_lengths"] - return text_inputs, mel_spec, mel_lengths + # texts.extend(convert_char_to_pinyin([text], polyphone=true)) + text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) + print(text_inputs) + mel_spec = batch["features"] + mel_lengths = batch["features_lens"] + return text_inputs, mel_spec.to(device), mel_lengths.to(device) def compute_loss( @@ -584,34 +585,28 @@ def compute_loss( values >= 1.0 are fully warmed up and have all modules present. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device - (mel_spec, text_inputs, mel_lengths) = prepare_input(batch, device) + (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device) # at entry, TextTokens is (N, P) - assert text_inputs.ndim == 2 - assert mel_spec.ndim == 3 with torch.set_grad_enabled(is_training): loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths) assert loss.requires_grad == is_training + print(loss) + # from accelerate import Accelerator + # from accelerate.utils import DistributedDataParallelKwargs + # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + # accelerator = Accelerator( + # kwargs_handlers=[ddp_kwargs], + # ) + # accelerator.backward(loss) + # loss.backward() info = MetricsTracker() - exit(0) # with warnings.catch_warnings(): # warnings.simplefilter("ignore") - # info["frames"] = (audio_features_lens).sum().item() - # info["utterances"] = text_tokens.size(0) + # info["samples"] = mel_lengths.size(0) - # # # Note: We use reduction=sum while computing the loss. - # # info["loss"] = loss.detach().cpu().item() - # # for metric in metrics: - # # info[metric] = metrics[metric].detach().cpu().item() - # # del metrics - # # Note: We use reduction=sum while computing the loss. - # info["loss"] = loss.detach().cpu().item() * info["frames"] - - # for i in range(len(loss_list)): - # info[f"loss_{i}"] = loss_list[i].detach().cpu().item() * info["frames"] - # for i in range(len(acc_list)): - # info[f"acc_{i}"] = acc_list[i] * info["frames"] + # info["loss"] = loss.detach().cpu().item() * info["samples"] return loss, info @@ -734,6 +729,7 @@ def train_one_epoch( batch=batch, is_training=True, ) + # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( 1 / params.reset_interval @@ -753,7 +749,9 @@ def train_one_epoch( scaler.step(optimizer) scaler.update() - optimizer.zero_grad() + # optimizer.zero_grad() + # loss.backward() + # optimizer.step() for k in range(params.accumulate_grad_steps): if isinstance(scheduler, Eden): @@ -926,12 +924,7 @@ def run(rank, world_size, args): logging.info("Training started") if args.tensorboard and rank == 0: - if params.train_stage: - tb_writer = SummaryWriter( - log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}" - ) - else: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") else: tb_writer = None @@ -950,7 +943,7 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_model(params) - model = load_pretrained_checkpoint(model, params.pretrained_model_path) + # model = load_pretrained_checkpoint(model, params.pretrained_model_path) model = model.to(device) @@ -968,6 +961,7 @@ def run(rank, world_size, args): model_avg = copy.deepcopy(model).to(torch.float64) assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg ) @@ -1029,7 +1023,7 @@ def run(rank, world_size, args): dataset = TtsDataModule(args) train_cuts = dataset.train_cuts() - valid_cuts = dataset.dev_cuts() + valid_cuts = dataset.valid_cuts() train_cuts = filter_short_and_long_utterances( train_cuts, params.filter_min_duration, params.filter_max_duration @@ -1041,7 +1035,7 @@ def run(rank, world_size, args): train_dl = dataset.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_dl = dataset.dev_dataloaders(valid_cuts) + valid_dl = dataset.valid_dataloaders(valid_cuts) if params.oom_check: scan_pessimistic_batches_for_oom( @@ -1136,7 +1130,7 @@ def scan_pessimistic_batches_for_oom( "Sanity check -- see if any of the batches in epoch 1 would cause OOM." ) batches, crit_values = find_pessimistic_batches(train_dl.sampler) - + print(23333) dtype = torch.float32 if params.dtype in ["bfloat16", "bf16"]: dtype = torch.bfloat16 @@ -1145,16 +1139,17 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] + print(batch.keys()) try: with torch.cuda.amp.autocast(dtype=dtype): - _, loss, _ = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, tokenizer=tokenizer, batch=batch, is_training=True, ) - loss.backward() + loss.backward(retain_graph=True) optimizer.zero_grad() except Exception as e: if "CUDA out of memory" in str(e): diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py index 7a665a54c..b544f1d96 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py @@ -24,21 +24,22 @@ from pathlib import Path from typing import Any, Dict, Optional import torch -from fbank import MatchaFbank, MatchaFbankConfig + +# from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset, CutConcatenate, CutMix, DynamicBucketingSampler, PrecomputedFeatures, SimpleCutSampler, - SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, OnTheFlyFeatures, ) from lhotse.utils import fix_random_seed +from speech_synthesis import SpeechSynthesisDataset # noqa F401 from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -174,29 +175,32 @@ class TtsDataModule: """ logging.info("About to create train dataset") train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, + return_text=True, + return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." ) + # sampling_rate = 22050 + # config = MatchaFbankConfig( + # n_fft=1024, + # n_mels=80, + # sampling_rate=sampling_rate, + # hop_length=256, + # win_length=1024, + # f_min=0, + # f_max=8000, + # ) + # train = SpeechSynthesisDataset( + # return_text=True, + # return_tokens=False, + # feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + # return_cuts=self.args.return_cuts, + # ) if self.args.bucketing_sampler: logging.info("Using DynamicBucketingSampler.") @@ -242,26 +246,29 @@ class TtsDataModule: def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." ) + # sampling_rate = 22050 + # config = MatchaFbankConfig( + # n_fft=1024, + # n_mels=80, + # sampling_rate=sampling_rate, + # hop_length=256, + # win_length=1024, + # f_min=0, + # f_max=8000, + # ) + # validate = SpeechSynthesisDataset( + # return_text=True, + # return_tokens=False, + # feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + # return_cuts=self.args.return_cuts, + # ) else: validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, + return_text=True, + return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -286,26 +293,29 @@ class TtsDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.info("About to create test dataset") if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = MatchaFbankConfig( - n_fft=1024, - n_mels=80, - sampling_rate=sampling_rate, - hop_length=256, - win_length=1024, - f_min=0, - f_max=8000, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - return_cuts=self.args.return_cuts, + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." ) + # sampling_rate = 22050 + # config = MatchaFbankConfig( + # n_fft=1024, + # n_mels=80, + # sampling_rate=sampling_rate, + # hop_length=256, + # win_length=1024, + # f_min=0, + # f_max=8000, + # ) + # test = SpeechSynthesisDataset( + # return_text=True, + # return_tokens=False, + # feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + # return_cuts=self.args.return_cuts, + # ) else: test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, + return_text=True, + return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, )