From 10c099ac909f6406d35dae5a9db5c709e81c3c91 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:51:47 +0800 Subject: [PATCH] remove more unused code --- egs/ljspeech/TTS/matcha/{utils => }/audio.py | 0 .../TTS/matcha/export_onnx_hifigan.py | 1 - egs/ljspeech/TTS/matcha/inference.py | 4 +- egs/ljspeech/TTS/matcha/{utils => }/model.py | 0 .../TTS/matcha/models/baselightningmodule.py | 223 --------------- .../TTS/matcha/models/components/decoder.py | 1 - .../matcha/models/components/flow_matching.py | 5 - .../matcha/models/components/text_encoder.py | 6 +- egs/ljspeech/TTS/matcha/models/matcha_tts.py | 14 +- .../{utils => }/monotonic_align/.gitignore | 0 .../{utils => }/monotonic_align/__init__.py | 2 +- .../{utils => }/monotonic_align/core.pyx | 0 .../{utils => }/monotonic_align/setup.py | 0 egs/ljspeech/TTS/matcha/onnx_pretrained.py | 7 +- egs/ljspeech/TTS/matcha/train.py | 4 +- .../TTS/matcha/{utils2.py => utils.py} | 0 egs/ljspeech/TTS/matcha/utils/__init__.py | 6 - .../matcha/utils/generate_data_statistics.py | 123 --------- .../utils/get_durations_from_trained_model.py | 215 --------------- .../TTS/matcha/utils/instantiators.py | 60 ---- .../TTS/matcha/utils/logging_utils.py | 57 ---- egs/ljspeech/TTS/matcha/utils/pylogger.py | 29 -- egs/ljspeech/TTS/matcha/utils/rich_utils.py | 103 ------- egs/ljspeech/TTS/matcha/utils/utils.py | 261 ------------------ 24 files changed, 13 insertions(+), 1108 deletions(-) rename egs/ljspeech/TTS/matcha/{utils => }/audio.py (100%) rename egs/ljspeech/TTS/matcha/{utils => }/model.py (100%) delete mode 100644 egs/ljspeech/TTS/matcha/models/baselightningmodule.py rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/.gitignore (100%) rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/__init__.py (90%) rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/core.pyx (100%) rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/setup.py (100%) rename egs/ljspeech/TTS/matcha/{utils2.py => utils.py} (100%) delete mode 100644 egs/ljspeech/TTS/matcha/utils/__init__.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/instantiators.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/logging_utils.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/pylogger.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/rich_utils.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/utils.py diff --git a/egs/ljspeech/TTS/matcha/utils/audio.py b/egs/ljspeech/TTS/matcha/audio.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/audio.py rename to egs/ljspeech/TTS/matcha/audio.py diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index 3b2ebf502..af54f4e89 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -5,7 +5,6 @@ from typing import Any, Dict import onnx import torch - from inference import load_vocoder diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 209fb86b4..89a6b33ae 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -2,17 +2,17 @@ import argparse import datetime as dt +import json import logging from pathlib import Path -import json import numpy as np import soundfile as sf import torch from matcha.hifigan.config import v1, v2, v3 from matcha.hifigan.denoiser import Denoiser -from tokenizer import Tokenizer from matcha.hifigan.models import Generator as HiFiGAN +from tokenizer import Tokenizer from tqdm.auto import tqdm from train import get_model, get_params diff --git a/egs/ljspeech/TTS/matcha/utils/model.py b/egs/ljspeech/TTS/matcha/model.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/model.py rename to egs/ljspeech/TTS/matcha/model.py diff --git a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py deleted file mode 100644 index e80d2a5c9..000000000 --- a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -This is a base lightning module that can be used to train a model. -The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. -""" -import inspect -from abc import ABC -from typing import Any, Dict - -import torch -from lightning import LightningModule -from lightning.pytorch.utilities import grad_norm - -from matcha import utils -from matcha.utils.utils import plot_tensor - -log = utils.get_pylogger(__name__) - - -class BaseLightningClass(LightningModule, ABC): - def update_data_statistics(self, data_statistics): - if data_statistics is None: - data_statistics = { - "mel_mean": 0.0, - "mel_std": 1.0, - } - - self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) - self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) - - def configure_optimizers(self) -> Any: - optimizer = self.hparams.optimizer(params=self.parameters()) - if self.hparams.scheduler not in (None, {}): - scheduler_args = {} - # Manage last epoch for exponential schedulers - if ( - "last_epoch" - in inspect.signature(self.hparams.scheduler.scheduler).parameters - ): - if hasattr(self, "ckpt_loaded_epoch"): - current_epoch = self.ckpt_loaded_epoch - 1 - else: - current_epoch = -1 - - scheduler_args.update({"optimizer": optimizer}) - scheduler = self.hparams.scheduler.scheduler(**scheduler_args) - scheduler.last_epoch = current_epoch - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": self.hparams.scheduler.lightning_args.interval, - "frequency": self.hparams.scheduler.lightning_args.frequency, - "name": "learning_rate", - }, - } - - return {"optimizer": optimizer} - - def get_losses(self, batch): - x, x_lengths = batch["x"], batch["x_lengths"] - y, y_lengths = batch["y"], batch["y_lengths"] - spks = batch["spks"] - - dur_loss, prior_loss, diff_loss, *_ = self( - x=x, - x_lengths=x_lengths, - y=y, - y_lengths=y_lengths, - spks=spks, - out_size=self.out_size, - durations=batch["durations"], - ) - return { - "dur_loss": dur_loss, - "prior_loss": prior_loss, - "diff_loss": diff_loss, - } - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - self.ckpt_loaded_epoch = checkpoint[ - "epoch" - ] # pylint: disable=attribute-defined-outside-init - - def training_step(self, batch: Any, batch_idx: int): - loss_dict = self.get_losses(batch) - self.log( - "step", - float(self.global_step), - on_step=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - - self.log( - "sub_loss/train_dur_loss", - loss_dict["dur_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/train_prior_loss", - loss_dict["prior_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/train_diff_loss", - loss_dict["diff_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - - total_loss = sum(loss_dict.values()) - self.log( - "loss/train", - total_loss, - on_step=True, - on_epoch=True, - logger=True, - prog_bar=True, - sync_dist=True, - ) - - return {"loss": total_loss, "log": loss_dict} - - def validation_step(self, batch: Any, batch_idx: int): - loss_dict = self.get_losses(batch) - self.log( - "sub_loss/val_dur_loss", - loss_dict["dur_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/val_prior_loss", - loss_dict["prior_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/val_diff_loss", - loss_dict["diff_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - - total_loss = sum(loss_dict.values()) - self.log( - "loss/val", - total_loss, - on_step=True, - on_epoch=True, - logger=True, - prog_bar=True, - sync_dist=True, - ) - - return total_loss - - def on_validation_end(self) -> None: - if self.trainer.is_global_zero: - one_batch = next(iter(self.trainer.val_dataloaders)) - if self.current_epoch == 0: - log.debug("Plotting original samples") - for i in range(2): - y = one_batch["y"][i].unsqueeze(0).to(self.device) - self.logger.experiment.add_image( - f"original/{i}", - plot_tensor(y.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - - log.debug("Synthesising...") - for i in range(2): - x = one_batch["x"][i].unsqueeze(0).to(self.device) - x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) - spks = ( - one_batch["spks"][i].unsqueeze(0).to(self.device) - if one_batch["spks"] is not None - else None - ) - output = self.synthesise( - x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks - ) - y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] - attn = output["attn"] - self.logger.experiment.add_image( - f"generated_enc/{i}", - plot_tensor(y_enc.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - self.logger.experiment.add_image( - f"generated_dec/{i}", - plot_tensor(y_dec.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - self.logger.experiment.add_image( - f"alignment/{i}", - plot_tensor(attn.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - - def on_before_optimizer_step(self, optimizer): - self.log_dict( - {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()} - ) diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py index 5850f2639..14d19f5d4 100644 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -7,7 +7,6 @@ import torch.nn.functional as F from conformer import ConformerBlock from diffusers.models.activations import get_activation from einops import pack, rearrange, repeat - from matcha.models.components.transformer import BasicTransformerBlock diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 5a7226b4f..997689b1c 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -2,13 +2,8 @@ from abc import ABC import torch import torch.nn.functional as F - from matcha.models.components.decoder import Decoder -# from matcha.utils.pylogger import get_pylogger - -# log = get_pylogger(__name__) - class BASECFM(torch.nn.Module, ABC): def __init__( diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index 68f8ad864..ca77cba51 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -5,11 +5,7 @@ import math import torch import torch.nn as nn from einops import rearrange - -# import matcha.utils as utils -from matcha.utils.model import sequence_mask - -# log = utils.get_pylogger(__name__) +from matcha.model import sequence_mask class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index b1525695f..330d1dc47 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -2,23 +2,17 @@ import datetime as dt import math import random +import matcha.monotonic_align as monotonic_align import torch - -import matcha.utils.monotonic_align as monotonic_align - -# from matcha import utils -# from matcha.models.baselightningmodule import BaseLightningClass -from matcha.models.components.flow_matching import CFM -from matcha.models.components.text_encoder import TextEncoder -from matcha.utils.model import ( +from matcha.model import ( denormalize, duration_loss, fix_len_compatibility, generate_path, sequence_mask, ) - -# log = utils.get_pylogger(__name__) +from matcha.models.components.flow_matching import CFM +from matcha.models.components.text_encoder import TextEncoder class MatchaTTS(torch.nn.Module): # 🍵 diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore rename to egs/ljspeech/TTS/matcha/monotonic_align/.gitignore diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py similarity index 90% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py rename to egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index eee6e0d47..58286bdd4 100644 --- a/egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -1,7 +1,7 @@ import numpy as np import torch -from matcha.utils.monotonic_align.core import maximum_path_c +from matcha.monotonic_align.core import maximum_path_c def maximum_path(value, mask): diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/core.pyx rename to egs/ljspeech/TTS/matcha/monotonic_align/core.pyx diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/setup.py rename to egs/ljspeech/TTS/matcha/monotonic_align/setup.py diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 24955e881..3953d5d0a 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 +import datetime as dt import logging import onnxruntime as ort -import torch -from tokenizer import Tokenizer -import datetime as dt - import soundfile as sf +import torch from inference import load_vocoder +from tokenizer import Tokenizer class OnnxHifiGANModel: diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 7f41ab101..ce13e7e42 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -14,15 +14,15 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed +from matcha.model import fix_len_compatibility from matcha.models.matcha_tts import MatchaTTS from matcha.tokenizer import Tokenizer -from matcha.utils.model import fix_len_compatibility from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter from tts_datamodule import LJSpeechTtsDataModule -from utils2 import MetricsTracker +from utils import MetricsTracker from icefall.checkpoint import load_checkpoint, save_checkpoint from icefall.dist import cleanup_dist, setup_dist diff --git a/egs/ljspeech/TTS/matcha/utils2.py b/egs/ljspeech/TTS/matcha/utils.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils2.py rename to egs/ljspeech/TTS/matcha/utils.py diff --git a/egs/ljspeech/TTS/matcha/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py deleted file mode 100644 index 311744a78..000000000 --- a/egs/ljspeech/TTS/matcha/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers -# from matcha.utils.logging_utils import log_hyperparameters -# from matcha.utils.pylogger import get_pylogger -# from matcha.utils.rich_utils import enforce_tags, print_config_tree -# from matcha.utils.utils import extras, get_metric_value, task_wrapper -from matcha.utils.utils import intersperse diff --git a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py deleted file mode 100644 index 3028e7695..000000000 --- a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py +++ /dev/null @@ -1,123 +0,0 @@ -r""" -The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it -when needed. - -Parameters from hparam.py will be used -""" -import argparse -import json -import os -import sys -from pathlib import Path - -import rootutils -import torch -from hydra import compose, initialize -from omegaconf import open_dict -from tqdm.auto import tqdm - -from matcha.data.text_mel_datamodule import TextMelDataModule -from matcha.utils.logging_utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -def compute_data_statistics( - data_loader: torch.utils.data.DataLoader, out_channels: int -): - """Generate data mean and standard deviation helpful in data normalisation - - Args: - data_loader (torch.utils.data.Dataloader): _description_ - out_channels (int): mel spectrogram channels - """ - total_mel_sum = 0 - total_mel_sq_sum = 0 - total_mel_len = 0 - - for batch in tqdm(data_loader, leave=False): - mels = batch["y"] - mel_lengths = batch["y_lengths"] - - total_mel_len += torch.sum(mel_lengths) - total_mel_sum += torch.sum(mels) - total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) - - data_mean = total_mel_sum / (total_mel_len * out_channels) - data_std = torch.sqrt( - (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2) - ) - - return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "-i", - "--input-config", - type=str, - default="vctk.yaml", - help="The name of the yaml config file under configs/data", - ) - - parser.add_argument( - "-b", - "--batch-size", - type=int, - default="256", - help="Can have increased batch size for faster computation", - ) - - parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - required=False, - help="force overwrite the file", - ) - args = parser.parse_args() - output_file = Path(args.input_config).with_suffix(".json") - - if os.path.exists(output_file) and not args.force: - print("File already exists. Use -f to force overwrite") - sys.exit(1) - - with initialize(version_base="1.3", config_path="../../configs/data"): - cfg = compose( - config_name=args.input_config, return_hydra_config=True, overrides=[] - ) - - root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") - - with open_dict(cfg): - print(cfg) - del cfg["hydra"] - del cfg["_target_"] - cfg["data_statistics"] = None - cfg["seed"] = 1234 - cfg["batch_size"] = args.batch_size - cfg["train_filelist_path"] = str( - os.path.join(root_path, cfg["train_filelist_path"]) - ) - cfg["valid_filelist_path"] = str( - os.path.join(root_path, cfg["valid_filelist_path"]) - ) - cfg["load_durations"] = False - - text_mel_datamodule = TextMelDataModule(**cfg) - text_mel_datamodule.setup() - data_loader = text_mel_datamodule.train_dataloader() - log.info("Dataloader loaded! Now computing stats...") - params = compute_data_statistics(data_loader, cfg["n_feats"]) - print(params) - json.dump( - params, - open(output_file, "w"), - ) - - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py deleted file mode 100644 index acc7eabd9..000000000 --- a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py +++ /dev/null @@ -1,215 +0,0 @@ -r""" -The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it -when needed. - -Parameters from hparam.py will be used -""" -import argparse -import json -import os -import sys -from pathlib import Path - -import lightning -import numpy as np -import rootutils -import torch -from hydra import compose, initialize -from omegaconf import open_dict -from torch import nn -from tqdm.auto import tqdm - -from matcha.cli import get_device -from matcha.data.text_mel_datamodule import TextMelDataModule -from matcha.models.matcha_tts import MatchaTTS -from matcha.utils.logging_utils import pylogger -from matcha.utils.utils import get_phoneme_durations - -log = pylogger.get_pylogger(__name__) - - -def save_durations_to_folder( - attn: torch.Tensor, - x_length: int, - y_length: int, - filepath: str, - output_folder: Path, - text: str, -): - durations = attn.squeeze().sum(1)[:x_length].numpy() - durations_json = get_phoneme_durations(durations, text) - output = output_folder / Path(filepath).name.replace(".wav", ".npy") - with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: - json.dump(durations_json, f, indent=4, ensure_ascii=False) - - np.save(output, durations) - - -@torch.inference_mode() -def compute_durations( - data_loader: torch.utils.data.DataLoader, - model: nn.Module, - device: torch.device, - output_folder, -): - """Generate durations from the model for each datapoint and save it in a folder - - Args: - data_loader (torch.utils.data.DataLoader): Dataloader - model (nn.Module): MatchaTTS model - device (torch.device): GPU or CPU - """ - - for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"): - x, x_lengths = batch["x"], batch["x_lengths"] - y, y_lengths = batch["y"], batch["y_lengths"] - spks = batch["spks"] - x = x.to(device) - y = y.to(device) - x_lengths = x_lengths.to(device) - y_lengths = y_lengths.to(device) - spks = spks.to(device) if spks is not None else None - - _, _, _, attn = model( - x=x, - x_lengths=x_lengths, - y=y, - y_lengths=y_lengths, - spks=spks, - ) - attn = attn.cpu() - for i in range(attn.shape[0]): - save_durations_to_folder( - attn[i], - x_lengths[i].item(), - y_lengths[i].item(), - batch["filepaths"][i], - output_folder, - batch["x_texts"][i], - ) - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "-i", - "--input-config", - type=str, - default="ljspeech.yaml", - help="The name of the yaml config file under configs/data", - ) - - parser.add_argument( - "-b", - "--batch-size", - type=int, - default="32", - help="Can have increased batch size for faster computation", - ) - - parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - required=False, - help="force overwrite the file", - ) - parser.add_argument( - "-c", - "--checkpoint_path", - type=str, - required=True, - help="Path to the checkpoint file to load the model from", - ) - - parser.add_argument( - "-o", - "--output-folder", - type=str, - default=None, - help="Output folder to save the data statistics", - ) - - parser.add_argument( - "--cpu", - action="store_true", - help="Use CPU for inference, not recommended (default: use GPU if available)", - ) - - args = parser.parse_args() - - with initialize(version_base="1.3", config_path="../../configs/data"): - cfg = compose( - config_name=args.input_config, return_hydra_config=True, overrides=[] - ) - - root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") - - with open_dict(cfg): - del cfg["hydra"] - del cfg["_target_"] - cfg["seed"] = 1234 - cfg["batch_size"] = args.batch_size - cfg["train_filelist_path"] = str( - os.path.join(root_path, cfg["train_filelist_path"]) - ) - cfg["valid_filelist_path"] = str( - os.path.join(root_path, cfg["valid_filelist_path"]) - ) - cfg["load_durations"] = False - - if args.output_folder is not None: - output_folder = Path(args.output_folder) - else: - output_folder = Path(cfg["train_filelist_path"]).parent / "durations" - - print(f"Output folder set to: {output_folder}") - - if os.path.exists(output_folder) and not args.force: - print("Folder already exists. Use -f to force overwrite") - sys.exit(1) - - output_folder.mkdir(parents=True, exist_ok=True) - - print( - f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}" - ) - print("Loading model...") - device = get_device(args) - model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) - - text_mel_datamodule = TextMelDataModule(**cfg) - text_mel_datamodule.setup() - try: - print("Computing stats for training set if exists...") - train_dataloader = text_mel_datamodule.train_dataloader() - compute_durations(train_dataloader, model, device, output_folder) - except lightning.fabric.utilities.exceptions.MisconfigurationException: - print("No training set found") - - try: - print("Computing stats for validation set if exists...") - val_dataloader = text_mel_datamodule.val_dataloader() - compute_durations(val_dataloader, model, device, output_folder) - except lightning.fabric.utilities.exceptions.MisconfigurationException: - print("No validation set found") - - try: - print("Computing stats for test set if exists...") - test_dataloader = text_mel_datamodule.test_dataloader() - compute_durations(test_dataloader, model, device, output_folder) - except lightning.fabric.utilities.exceptions.MisconfigurationException: - print("No test set found") - - print(f"[+] Done! Data statistics saved to: {output_folder}") - - -if __name__ == "__main__": - # Helps with generating durations for the dataset to train other architectures - # that cannot learn to align due to limited size of dataset - # Example usage: - # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model - # This will create a folder in data/processed_data/durations/ljspeech with the durations - main() diff --git a/egs/ljspeech/TTS/matcha/utils/instantiators.py b/egs/ljspeech/TTS/matcha/utils/instantiators.py deleted file mode 100644 index bde0c0d75..000000000 --- a/egs/ljspeech/TTS/matcha/utils/instantiators.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import List - -import hydra -from lightning import Callback -from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig - -from matcha.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config. - - :param callbacks_cfg: A DictConfig object containing callback configurations. - :return: A list of instantiated callbacks. - """ - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info( - f"Instantiating callback <{cb_conf._target_}>" - ) # pylint: disable=protected-access - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config. - - :param logger_cfg: A DictConfig object containing logger configurations. - :return: A list of instantiated loggers. - """ - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info( - f"Instantiating logger <{lg_conf._target_}>" - ) # pylint: disable=protected-access - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger diff --git a/egs/ljspeech/TTS/matcha/utils/logging_utils.py b/egs/ljspeech/TTS/matcha/utils/logging_utils.py deleted file mode 100644 index 2d2377eb2..000000000 --- a/egs/ljspeech/TTS/matcha/utils/logging_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, Dict - -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import OmegaConf - -from matcha.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -@rank_zero_only -def log_hyperparameters(object_dict: Dict[str, Any]) -> None: - """Controls which config parts are saved by Lightning loggers. - - Additionally saves: - - Number of model parameters - - :param object_dict: A dictionary containing the following objects: - - `"cfg"`: A DictConfig object containing the main config. - - `"model"`: The Lightning model. - - `"trainer"`: The Lightning trainer. - """ - hparams = {} - - cfg = OmegaConf.to_container(object_dict["cfg"]) - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - hparams["data"] = cfg["data"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) diff --git a/egs/ljspeech/TTS/matcha/utils/pylogger.py b/egs/ljspeech/TTS/matcha/utils/pylogger.py deleted file mode 100644 index a7ed7a961..000000000 --- a/egs/ljspeech/TTS/matcha/utils/pylogger.py +++ /dev/null @@ -1,29 +0,0 @@ -import logging - -from lightning.pytorch.utilities import rank_zero_only - - -def get_pylogger(name: str = __name__) -> logging.Logger: - """Initializes a multi-GPU-friendly python command line logger. - - :param name: The name of the logger, defaults to ``__name__``. - - :return: A logger object. - """ - logger = logging.getLogger(name) - - # this ensures all logging levels get marked with the rank zero decorator - # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ( - "debug", - "info", - "warning", - "error", - "exception", - "fatal", - "critical", - ) - for level in logging_levels: - setattr(logger, level, rank_zero_only(getattr(logger, level))) - - return logger diff --git a/egs/ljspeech/TTS/matcha/utils/rich_utils.py b/egs/ljspeech/TTS/matcha/utils/rich_utils.py deleted file mode 100644 index d7fcd1aae..000000000 --- a/egs/ljspeech/TTS/matcha/utils/rich_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -from pathlib import Path -from typing import Sequence - -import rich -import rich.syntax -import rich.tree -from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import DictConfig, OmegaConf, open_dict -from rich.prompt import Prompt - -from matcha.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -@rank_zero_only -def print_config_tree( - cfg: DictConfig, - print_order: Sequence[str] = ( - "data", - "model", - "callbacks", - "logger", - "trainer", - "paths", - "extras", - ), - resolve: bool = False, - save_to_file: bool = False, -) -> None: - """Prints the contents of a DictConfig as a tree structure using the Rich library. - - :param cfg: A DictConfig composed by Hydra. - :param print_order: Determines in what order config components are printed. Default is ``("data", "model", - "callbacks", "logger", "trainer", "paths", "extras")``. - :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. - :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. - """ - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - queue = [] - - # add fields from `print_order` to queue - for field in print_order: - _ = ( - queue.append(field) - if field in cfg - else log.warning( - f"Field '{field}' not found in config. Skipping '{field}' config printing..." - ) - ) - - # add all the other fields to queue (not specified in `print_order`) - for field in cfg: - if field not in queue: - queue.append(field) - - # generate config tree from queue - for field in queue: - branch = tree.add(field, style=style, guide_style=style) - - config_group = cfg[field] - if isinstance(config_group, DictConfig): - branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) - else: - branch_content = str(config_group) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - # print config tree - rich.print(tree) - - # save config tree to file - if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: - rich.print(tree, file=file) - - -@rank_zero_only -def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config. - - :param cfg: A DictConfig composed by Hydra. - :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. - """ - if not cfg.get("tags"): - if "id" in HydraConfig().cfg.hydra.job: - raise ValueError("Specify tags before launching a multirun!") - - log.warning("No tags provided in config. Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") - tags = [t.strip() for t in tags.split(",") if t != ""] - - with open_dict(cfg): - cfg.tags = tags - - log.info(f"Tags: {cfg.tags}") - - if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: - rich.print(cfg.tags, file=file) diff --git a/egs/ljspeech/TTS/matcha/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py deleted file mode 100644 index a54554263..000000000 --- a/egs/ljspeech/TTS/matcha/utils/utils.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -import sys -import warnings -from importlib.util import find_spec -from math import ceil -from pathlib import Path -from typing import Any, Callable, Dict, Tuple - -import matplotlib.pyplot as plt -import numpy as np -import torch - -# from omegaconf import DictConfig - -# from matcha.utils import pylogger, rich_utils - -# log = pylogger.get_pylogger(__name__) - - -def extras(cfg: "DictConfig") -> None: - """Applies optional utilities before the task is started. - - Utilities: - - Ignoring python warnings - - Setting tags from command line - - Rich config printing - - :param cfg: A DictConfig object containing the config tree. - """ - # return if no `extras` config - if not cfg.get("extras"): - log.warning("Extras config not found! ") - return - - # disable python warnings - if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! ") - warnings.filterwarnings("ignore") - - # prompt user to input tags from command line if none are provided in the config - if cfg.extras.get("enforce_tags"): - log.info("Enforcing tags! ") - rich_utils.enforce_tags(cfg, save_to_file=True) - - # pretty print config tree using Rich library - if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! ") - rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) - - -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the task function. - - This wrapper can be used to: - - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) - - save the exception to a `.log` file - - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) - - etc. (adjust depending on your needs) - - Example: - ``` - @utils.task_wrapper - def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: - ... - return metric_dict, object_dict - ``` - - :param task_func: The task function to be wrapped. - - :return: The wrapped task function. - """ - - def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: - # execute the task - try: - metric_dict, object_dict = task_func(cfg=cfg) - - # things to do if exception occurs - except Exception as ex: - # save exception to `.log` file - log.exception("") - - # some hyperparameter combinations might be invalid or cause out-of-memory errors - # so when using hparam search plugins like Optuna, you might want to disable - # raising the below exception to avoid multirun failure - raise ex - - # things to always do after either success or exception - finally: - # display output dir path in terminal - log.info(f"Output dir: {cfg.paths.output_dir}") - - # always close wandb run (even if exception occurs so multirun won't fail) - if find_spec("wandb"): # check if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - return metric_dict, object_dict - - return wrap - - -def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule. - - :param metric_dict: A dict containing metric values. - :param metric_name: The name of the metric to retrieve. - :return: The value of the metric. - """ - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise ValueError( - f"Metric value not found! \n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") - - return metric_value - - -def intersperse(lst, item): - # Adds blank symbol - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def save_figure_to_numpy(fig): - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - return data - - -def plot_tensor(tensor): - plt.style.use("default") - fig, ax = plt.subplots(figsize=(12, 3)) - im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.tight_layout() - fig.canvas.draw() - data = save_figure_to_numpy(fig) - plt.close() - return data - - -def save_plot(tensor, savepath): - plt.style.use("default") - fig, ax = plt.subplots(figsize=(12, 3)) - im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.tight_layout() - fig.canvas.draw() - plt.savefig(savepath) - plt.close() - - -def to_numpy(tensor): - if isinstance(tensor, np.ndarray): - return tensor - elif isinstance(tensor, torch.Tensor): - return tensor.detach().cpu().numpy() - elif isinstance(tensor, list): - return np.array(tensor) - else: - raise TypeError("Unsupported type for conversion to numpy array") - - -def get_user_data_dir(appname="matcha_tts"): - """ - Args: - appname (str): Name of application - - Returns: - Path: path to user data directory - """ - - MATCHA_HOME = os.environ.get("MATCHA_HOME") - if MATCHA_HOME is not None: - ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) - elif sys.platform == "win32": - import winreg # pylint: disable=import-outside-toplevel - - key = winreg.OpenKey( - winreg.HKEY_CURRENT_USER, - r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", - ) - dir_, _ = winreg.QueryValueEx(key, "Local AppData") - ans = Path(dir_).resolve(strict=False) - elif sys.platform == "darwin": - ans = Path("~/Library/Application Support/").expanduser() - else: - ans = Path.home().joinpath(".local/share") - - final_path = ans.joinpath(appname) - final_path.mkdir(parents=True, exist_ok=True) - return final_path - - -def assert_model_downloaded(checkpoint_path, url, use_wget=True): - import gdown - import wget - - if Path(checkpoint_path).exists(): - log.debug(f"[+] Model already present at {checkpoint_path}!") - print(f"[+] Model already present at {checkpoint_path}!") - return - log.info(f"[-] Model not found at {checkpoint_path}! Will download it") - print(f"[-] Model not found at {checkpoint_path}! Will download it") - checkpoint_path = str(checkpoint_path) - if not use_wget: - gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) - else: - wget.download(url=url, out=checkpoint_path) - - -def get_phoneme_durations(durations, phones): - prev = durations[0] - merged_durations = [] - # Convolve with stride 2 - for i in range(1, len(durations), 2): - if i == len(durations) - 2: - # if it is last take full value - next_half = durations[i + 1] - else: - next_half = ceil(durations[i + 1] / 2) - - curr = prev + durations[i] + next_half - prev = durations[i + 1] - next_half - merged_durations.append(curr) - - assert len(phones) == len(merged_durations) - assert len(merged_durations) == (len(durations) - 1) // 2 - - merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) - start = torch.tensor(0) - duration_json = [] - for i, duration in enumerate(merged_durations): - duration_json.append( - { - phones[i]: { - "starttime": start.item(), - "endtime": duration.item(), - "duration": duration.item() - start.item(), - } - } - ) - start = duration - - assert list(duration_json[-1].values())[0]["endtime"] == sum( - durations - ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" - return duration_json