From ed569a938ac578e436f2b433e6f4dbde07fe91b9 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 28 Oct 2024 19:20:21 +0800
Subject: [PATCH] remove more unused code

---
 egs/ljspeech/TTS/matcha/test-train.py | 159 --------------------------
 egs/ljspeech/TTS/matcha/train-orig.py | 122 --------------------
 2 files changed, 281 deletions(-)
 delete mode 100644 egs/ljspeech/TTS/matcha/test-train.py
 delete mode 100644 egs/ljspeech/TTS/matcha/train-orig.py

diff --git a/egs/ljspeech/TTS/matcha/test-train.py b/egs/ljspeech/TTS/matcha/test-train.py
deleted file mode 100644
index f41ee4eae..000000000
--- a/egs/ljspeech/TTS/matcha/test-train.py
+++ /dev/null
@@ -1,159 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2023  Xiaomi Corp.  (authors: Fangjun Kuang)
-
-
-import torch
-
-
-from icefall.utils import AttributeDict
-from matcha.models.matcha_tts import MatchaTTS
-from matcha.data.text_mel_datamodule import TextMelDataModule
-
-
-def _get_data_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "name": "ljspeech",
-            "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
-            "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "batch_size": 32,
-            "num_workers": 3,
-            "pin_memory": False,
-            "cleaners": ["english_cleaners2"],
-            "add_blank": True,
-            "n_spks": 1,
-            "n_fft": 1024,
-            "n_feats": 80,
-            "sample_rate": 22050,
-            "hop_length": 256,
-            "win_length": 1024,
-            "f_min": 0,
-            "f_max": 8000,
-            "seed": 1234,
-            "load_durations": False,
-            "data_statistics": AttributeDict(
-                {
-                    "mel_mean": -5.517028331756592,
-                    "mel_std": 2.0643954277038574,
-                }
-            ),
-        }
-    )
-    return params
-
-
-def _get_model_params() -> AttributeDict:
-    n_feats = 80
-    filter_channels_dp = 256
-    encoder_params_p_dropout = 0.1
-    params = AttributeDict(
-        {
-            "n_vocab": 178,
-            "n_spks": 1,  # for ljspeech.
-            "spk_emb_dim": 64,
-            "n_feats": n_feats,
-            "out_size": None,  # or use 172
-            "prior_loss": True,
-            "use_precomputed_durations": False,
-            "encoder": AttributeDict(
-                {
-                    "encoder_type": "RoPE Encoder",  # not used
-                    "encoder_params": AttributeDict(
-                        {
-                            "n_feats": n_feats,
-                            "n_channels": 192,
-                            "filter_channels": 768,
-                            "filter_channels_dp": filter_channels_dp,
-                            "n_heads": 2,
-                            "n_layers": 6,
-                            "kernel_size": 3,
-                            "p_dropout": encoder_params_p_dropout,
-                            "spk_emb_dim": 64,
-                            "n_spks": 1,
-                            "prenet": True,
-                        }
-                    ),
-                    "duration_predictor_params": AttributeDict(
-                        {
-                            "filter_channels_dp": filter_channels_dp,
-                            "kernel_size": 3,
-                            "p_dropout": encoder_params_p_dropout,
-                        }
-                    ),
-                }
-            ),
-            "decoder": AttributeDict(
-                {
-                    "channels": [256, 256],
-                    "dropout": 0.05,
-                    "attention_head_dim": 64,
-                    "n_blocks": 1,
-                    "num_mid_blocks": 2,
-                    "num_heads": 2,
-                    "act_fn": "snakebeta",
-                }
-            ),
-            "cfm": AttributeDict(
-                {
-                    "name": "CFM",
-                    "solver": "euler",
-                    "sigma_min": 1e-4,
-                }
-            ),
-            "optimizer": AttributeDict(
-                {
-                    "lr": 1e-4,
-                    "weight_decay": 0.0,
-                }
-            ),
-        }
-    )
-
-    return params
-
-
-def get_params():
-    params = AttributeDict(
-        {
-            "model": _get_model_params(),
-            "data": _get_data_params(),
-        }
-    )
-    return params
-
-
-def get_model(params):
-    m = MatchaTTS(**params.model)
-    return m
-
-
-def main():
-    params = get_params()
-
-    data_module = TextMelDataModule(hparams=params.data)
-    if False:
-        for b in data_module.train_dataloader():
-            assert isinstance(b, dict)
-            # b.keys()
-            # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
-            # x: [batch_size, 289], torch.int64
-            # x_lengths: [batch_size], torch.int64
-            # y: [batch_size, n_feats, num_frames], torch.float32
-            # y_lengths: [batch_size], torch.int64
-            # spks: None
-            # filepaths: list, (batch_size,)
-            # x_texts: list, (batch_size,)
-            # durations: None
-
-    m = get_model(params)
-    print(m)
-
-    num_param = sum([p.numel() for p in m.parameters()])
-    print(f"Number of parameters: {num_param}")
-
-
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/ljspeech/TTS/matcha/train-orig.py b/egs/ljspeech/TTS/matcha/train-orig.py
deleted file mode 100644
index d1d64c6c4..000000000
--- a/egs/ljspeech/TTS/matcha/train-orig.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-import hydra
-import lightning as L
-import rootutils
-from lightning import Callback, LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger
-from omegaconf import DictConfig
-
-from matcha import utils
-
-rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-# ------------------------------------------------------------------------------------ #
-# the setup_root above is equivalent to:
-# - adding project root dir to PYTHONPATH
-#       (so you don't need to force user to install project as a package)
-#       (necessary before importing any local modules e.g. `from src import utils`)
-# - setting up PROJECT_ROOT environment variable
-#       (which is used as a base for paths in "configs/paths/default.yaml")
-#       (this way all filepaths are the same no matter where you run the code)
-# - loading environment variables from ".env" in root dir
-#
-# you can remove it if you:
-# 1. either install project as a package or move entry files to project root dir
-# 2. set `root_dir` to "." in "configs/paths/default.yaml"
-#
-# more info: https://github.com/ashleve/rootutils
-# ------------------------------------------------------------------------------------ #
-
-
-log = utils.get_pylogger(__name__)
-
-
-@utils.task_wrapper
-def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
-    training.
-
-    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
-    failure. Useful for multiruns, saving info about the crash, etc.
-
-    :param cfg: A DictConfig configuration composed by Hydra.
-    :return: A tuple with metrics and dict with all instantiated objects.
-    """
-    # set seed for random number generators in pytorch, numpy and python.random
-    if cfg.get("seed"):
-        L.seed_everything(cfg.seed, workers=True)
-
-    log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
-    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
-
-    log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
-    model: LightningModule = hydra.utils.instantiate(cfg.model)
-
-    log.info("Instantiating callbacks...")
-    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
-
-    log.info("Instantiating loggers...")
-    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
-
-    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
-    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
-
-    object_dict = {
-        "cfg": cfg,
-        "datamodule": datamodule,
-        "model": model,
-        "callbacks": callbacks,
-        "logger": logger,
-        "trainer": trainer,
-    }
-
-    if logger:
-        log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
-
-    if cfg.get("train"):
-        log.info("Starting training!")
-        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
-
-    train_metrics = trainer.callback_metrics
-
-    if cfg.get("test"):
-        log.info("Starting testing!")
-        ckpt_path = trainer.checkpoint_callback.best_model_path
-        if ckpt_path == "":
-            log.warning("Best ckpt not found! Using current weights for testing...")
-            ckpt_path = None
-        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
-        log.info(f"Best ckpt path: {ckpt_path}")
-
-    test_metrics = trainer.callback_metrics
-
-    # merge train and test metrics
-    metric_dict = {**train_metrics, **test_metrics}
-
-    return metric_dict, object_dict
-
-
-@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
-def main(cfg: DictConfig) -> Optional[float]:
-    """Main entry point for training.
-
-    :param cfg: DictConfig configuration composed by Hydra.
-    :return: Optional[float] with optimized metric value.
-    """
-    # apply extra utilities
-    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
-    utils.extras(cfg)
-
-    # train the model
-    metric_dict, _ = train(cfg)
-
-    # safely retrieve metric value for hydra-based hyperparameter optimization
-    metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
-
-    # return optimized metric
-    return metric_value
-
-
-if __name__ == "__main__":
-    main()  # pylint: disable=no-value-for-parameter