From ccd2dcc9f9919567839af714762ca5458815ba82 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 15 Oct 2024 22:48:35 +0800
Subject: [PATCH] add dataset

---
Review notes (text between the "---" line and the diffstat is not part of
the commit):

* This patch was recovered from a copy whose line breaks had been collapsed.
  Blank context lines and Python indentation are reconstructed from the hunk
  line counts and the bracket structure. NOTE(review): verify against the
  pre-image before applying, in particular the first context line of the
  first train.py hunk (assumed to be a blank line following "import torch").
* NOTE(review): the From: header carries no e-mail address in the recovered
  copy; git-am may refuse the patch until one is supplied.
* train.py: the hyper-parameter dict that main() used to build inline moves
  into _get_model_params(); _get_data_params() and get_params() are new.
  The dataloader loop added to main() is guarded by `if False:` and never
  runs; it only documents the batch layout.
* utils.py: gdown/wget become function-local imports inside
  assert_model_downloaded(). The module-level `log` is commented out, but
  assert_model_downloaded() still calls log.debug(...); unless `log` is
  defined elsewhere in the file, that branch raises NameError when the
  checkpoint already exists.

 egs/ljspeech/TTS/matcha/train.py       | 199 ++++++++++++++++---------
 egs/ljspeech/TTS/matcha/utils/utils.py |  12 +-
 2 files changed, 136 insertions(+), 75 deletions(-)

diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index 1c5084204..f41ee4eae 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -7,6 +7,119 @@ import torch
 
 from icefall.utils import AttributeDict
 from matcha.models.matcha_tts import MatchaTTS
+from matcha.data.text_mel_datamodule import TextMelDataModule
+
+
+def _get_data_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "name": "ljspeech",
+            "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
+            "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
+            "batch_size": 32,
+            "num_workers": 3,
+            "pin_memory": False,
+            "cleaners": ["english_cleaners2"],
+            "add_blank": True,
+            "n_spks": 1,
+            "n_fft": 1024,
+            "n_feats": 80,
+            "sample_rate": 22050,
+            "hop_length": 256,
+            "win_length": 1024,
+            "f_min": 0,
+            "f_max": 8000,
+            "seed": 1234,
+            "load_durations": False,
+            "data_statistics": AttributeDict(
+                {
+                    "mel_mean": -5.517028331756592,
+                    "mel_std": 2.0643954277038574,
+                }
+            ),
+        }
+    )
+    return params
+
+
+def _get_model_params() -> AttributeDict:
+    n_feats = 80
+    filter_channels_dp = 256
+    encoder_params_p_dropout = 0.1
+    params = AttributeDict(
+        {
+            "n_vocab": 178,
+            "n_spks": 1,  # for ljspeech.
+            "spk_emb_dim": 64,
+            "n_feats": n_feats,
+            "out_size": None,  # or use 172
+            "prior_loss": True,
+            "use_precomputed_durations": False,
+            "encoder": AttributeDict(
+                {
+                    "encoder_type": "RoPE Encoder",  # not used
+                    "encoder_params": AttributeDict(
+                        {
+                            "n_feats": n_feats,
+                            "n_channels": 192,
+                            "filter_channels": 768,
+                            "filter_channels_dp": filter_channels_dp,
+                            "n_heads": 2,
+                            "n_layers": 6,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                            "spk_emb_dim": 64,
+                            "n_spks": 1,
+                            "prenet": True,
+                        }
+                    ),
+                    "duration_predictor_params": AttributeDict(
+                        {
+                            "filter_channels_dp": filter_channels_dp,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                        }
+                    ),
+                }
+            ),
+            "decoder": AttributeDict(
+                {
+                    "channels": [256, 256],
+                    "dropout": 0.05,
+                    "attention_head_dim": 64,
+                    "n_blocks": 1,
+                    "num_mid_blocks": 2,
+                    "num_heads": 2,
+                    "act_fn": "snakebeta",
+                }
+            ),
+            "cfm": AttributeDict(
+                {
+                    "name": "CFM",
+                    "solver": "euler",
+                    "sigma_min": 1e-4,
+                }
+            ),
+            "optimizer": AttributeDict(
+                {
+                    "lr": 1e-4,
+                    "weight_decay": 0.0,
+                }
+            ),
+        }
+    )
+
+    return params
+
+
+def get_params():
+    params = AttributeDict(
+        {
+            "model": _get_model_params(),
+            "data": _get_data_params(),
+        }
+    )
+    return params
 
 
 def get_model(params):
@@ -15,75 +128,23 @@ def get_model(params):
 
 
 def main():
-    n_feats = 80
-    filter_channels_dp = 256
-    encoder_params_p_dropout = 0.1
-    params = AttributeDict(
-        {
-            "model": AttributeDict(
-                {
-                    "n_vocab": 178,
-                    "n_spks": 1,  # for ljspeech.
-                    "spk_emb_dim": 64,
-                    "n_feats": n_feats,
-                    "out_size": None,  # or use 172
-                    "prior_loss": True,
-                    "use_precomputed_durations": False,
-                    "encoder": AttributeDict(
-                        {
-                            "encoder_type": "RoPE Encoder",  # not used
-                            "encoder_params": AttributeDict(
-                                {
-                                    "n_feats": n_feats,
-                                    "n_channels": 192,
-                                    "filter_channels": 768,
-                                    "filter_channels_dp": filter_channels_dp,
-                                    "n_heads": 2,
-                                    "n_layers": 6,
-                                    "kernel_size": 3,
-                                    "p_dropout": encoder_params_p_dropout,
-                                    "spk_emb_dim": 64,
-                                    "n_spks": 1,
-                                    "prenet": True,
-                                }
-                            ),
-                            "duration_predictor_params": AttributeDict(
-                                {
-                                    "filter_channels_dp": filter_channels_dp,
-                                    "kernel_size": 3,
-                                    "p_dropout": encoder_params_p_dropout,
-                                }
-                            ),
-                        }
-                    ),
-                    "decoder": AttributeDict(
-                        {
-                            "channels": [256, 256],
-                            "dropout": 0.05,
-                            "attention_head_dim": 64,
-                            "n_blocks": 1,
-                            "num_mid_blocks": 2,
-                            "num_heads": 2,
-                            "act_fn": "snakebeta",
-                        }
-                    ),
-                    "cfm": AttributeDict(
-                        {
-                            "name": "CFM",
-                            "solver": "euler",
-                            "sigma_min": 1e-4,
-                        }
-                    ),
-                    "optimizer": AttributeDict(
-                        {
-                            "lr": 1e-4,
-                            "weight_decay": 0.0,
-                        }
-                    ),
-                }
-            )
-        }
-    )
+    params = get_params()
+
+    data_module = TextMelDataModule(hparams=params.data)
+    if False:
+        for b in data_module.train_dataloader():
+            assert isinstance(b, dict)
+            # b.keys()
+            # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
+            # x: [batch_size, 289], torch.int64
+            # x_lengths: [batch_size], torch.int64
+            # y: [batch_size, n_feats, num_frames], torch.float32
+            # y_lengths: [batch_size], torch.int64
+            # spks: None
+            # filepaths: list, (batch_size,)
+            # x_texts: list, (batch_size,)
+            # durations: None
+
     m = get_model(params)
     print(m)
 
diff --git a/egs/ljspeech/TTS/matcha/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py
index fc3a48ec2..bc81c316e 100644
--- a/egs/ljspeech/TTS/matcha/utils/utils.py
+++ b/egs/ljspeech/TTS/matcha/utils/utils.py
@@ -6,19 +6,17 @@ from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple
 
-import gdown
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import wget
-from omegaconf import DictConfig
+# from omegaconf import DictConfig
 
-from matcha.utils import pylogger, rich_utils
+# from matcha.utils import pylogger, rich_utils
 
-log = pylogger.get_pylogger(__name__)
+# log = pylogger.get_pylogger(__name__)
 
 
-def extras(cfg: DictConfig) -> None:
+def extras(cfg: 'DictConfig') -> None:
     """Applies optional utilities before the task is started.
 
     Utilities:
@@ -207,6 +205,8 @@ def get_user_data_dir(appname="matcha_tts"):
 
 
 def assert_model_downloaded(checkpoint_path, url, use_wget=True):
+    import gdown
+    import wget
     if Path(checkpoint_path).exists():
         log.debug(f"[+] Model already present at {checkpoint_path}!")
         print(f"[+] Model already present at {checkpoint_path}!")