add dataset

Fangjun Kuang 2024-10-15 22:48:35 +08:00
parent 6fac3a3143
commit ccd2dcc9f9
2 changed files with 136 additions and 75 deletions


@@ -7,6 +7,119 @@ import torch
 from icefall.utils import AttributeDict
 from matcha.models.matcha_tts import MatchaTTS
+from matcha.data.text_mel_datamodule import TextMelDataModule
+
+
+def _get_data_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "name": "ljspeech",
+            "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
+            "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
+            "batch_size": 32,
+            "num_workers": 3,
+            "pin_memory": False,
+            "cleaners": ["english_cleaners2"],
+            "add_blank": True,
+            "n_spks": 1,
+            "n_fft": 1024,
+            "n_feats": 80,
+            "sample_rate": 22050,
+            "hop_length": 256,
+            "win_length": 1024,
+            "f_min": 0,
+            "f_max": 8000,
+            "seed": 1234,
+            "load_durations": False,
+            "data_statistics": AttributeDict(
+                {
+                    "mel_mean": -5.517028331756592,
+                    "mel_std": 2.0643954277038574,
+                }
+            ),
+        }
+    )
+    return params
+
+
+def _get_model_params() -> AttributeDict:
+    n_feats = 80
+    filter_channels_dp = 256
+    encoder_params_p_dropout = 0.1
+    params = AttributeDict(
+        {
+            "n_vocab": 178,
+            "n_spks": 1,  # for ljspeech.
+            "spk_emb_dim": 64,
+            "n_feats": n_feats,
+            "out_size": None,  # or use 172
+            "prior_loss": True,
+            "use_precomputed_durations": False,
+            "encoder": AttributeDict(
+                {
+                    "encoder_type": "RoPE Encoder",  # not used
+                    "encoder_params": AttributeDict(
+                        {
+                            "n_feats": n_feats,
+                            "n_channels": 192,
+                            "filter_channels": 768,
+                            "filter_channels_dp": filter_channels_dp,
+                            "n_heads": 2,
+                            "n_layers": 6,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                            "spk_emb_dim": 64,
+                            "n_spks": 1,
+                            "prenet": True,
+                        }
+                    ),
+                    "duration_predictor_params": AttributeDict(
+                        {
+                            "filter_channels_dp": filter_channels_dp,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                        }
+                    ),
+                }
+            ),
+            "decoder": AttributeDict(
+                {
+                    "channels": [256, 256],
+                    "dropout": 0.05,
+                    "attention_head_dim": 64,
+                    "n_blocks": 1,
+                    "num_mid_blocks": 2,
+                    "num_heads": 2,
+                    "act_fn": "snakebeta",
+                }
+            ),
+            "cfm": AttributeDict(
+                {
+                    "name": "CFM",
+                    "solver": "euler",
+                    "sigma_min": 1e-4,
+                }
+            ),
+            "optimizer": AttributeDict(
+                {
+                    "lr": 1e-4,
+                    "weight_decay": 0.0,
+                }
+            ),
+        }
+    )
+    return params
+
+
+def get_params():
+    params = AttributeDict(
+        {
+            "model": _get_model_params(),
+            "data": _get_data_params(),
+        }
+    )
+    return params
+
+
 def get_model(params):
@@ -15,75 +128,23 @@ def get_model(params):
 def main():
-    n_feats = 80
-    filter_channels_dp = 256
-    encoder_params_p_dropout = 0.1
-    params = AttributeDict(
-        {
-            "model": AttributeDict(
-                {
-                    "n_vocab": 178,
-                    "n_spks": 1,  # for ljspeech.
-                    "spk_emb_dim": 64,
-                    "n_feats": n_feats,
-                    "out_size": None,  # or use 172
-                    "prior_loss": True,
-                    "use_precomputed_durations": False,
-                    "encoder": AttributeDict(
-                        {
-                            "encoder_type": "RoPE Encoder",  # not used
-                            "encoder_params": AttributeDict(
-                                {
-                                    "n_feats": n_feats,
-                                    "n_channels": 192,
-                                    "filter_channels": 768,
-                                    "filter_channels_dp": filter_channels_dp,
-                                    "n_heads": 2,
-                                    "n_layers": 6,
-                                    "kernel_size": 3,
-                                    "p_dropout": encoder_params_p_dropout,
-                                    "spk_emb_dim": 64,
-                                    "n_spks": 1,
-                                    "prenet": True,
-                                }
-                            ),
-                            "duration_predictor_params": AttributeDict(
-                                {
-                                    "filter_channels_dp": filter_channels_dp,
-                                    "kernel_size": 3,
-                                    "p_dropout": encoder_params_p_dropout,
-                                }
-                            ),
-                        }
-                    ),
-                    "decoder": AttributeDict(
-                        {
-                            "channels": [256, 256],
-                            "dropout": 0.05,
-                            "attention_head_dim": 64,
-                            "n_blocks": 1,
-                            "num_mid_blocks": 2,
-                            "num_heads": 2,
-                            "act_fn": "snakebeta",
-                        }
-                    ),
-                    "cfm": AttributeDict(
-                        {
-                            "name": "CFM",
-                            "solver": "euler",
-                            "sigma_min": 1e-4,
-                        }
-                    ),
-                    "optimizer": AttributeDict(
-                        {
-                            "lr": 1e-4,
-                            "weight_decay": 0.0,
-                        }
-                    ),
-                }
-            )
-        }
-    )
+    params = get_params()
+
+    data_module = TextMelDataModule(hparams=params.data)
+    if False:
+        for b in data_module.train_dataloader():
+            assert isinstance(b, dict)
+            # b.keys()
+            # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
+            # x: [batch_size, 289], torch.int64
+            # x_lengths: [batch_size], torch.int64
+            # y: [batch_size, n_feats, num_frames], torch.float32
+            # y_lengths: [batch_size], torch.int64
+            # spks: None
+            # filepaths: list, (batch_size,)
+            # x_texts: list, (batch_size,)
+            # durations: None
+
     m = get_model(params)
     print(m)
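
The disabled if False: block above documents the batch layout that TextMelDataModule yields. Below is a minimal inspection sketch, not part of the commit, under two assumptions: the LJSpeech filelists referenced in _get_data_params() exist locally, and the data module needs an explicit setup() call before its dataloaders are usable, as is typical for Lightning data modules. With sample_rate=22050 and hop_length=256, one mel frame spans 256/22050 ≈ 11.6 ms, about 86 frames per second, which is how y's num_frames relates to audio duration.

    # Hypothetical inspection script (assumed names: only get_params and
    # TextMelDataModule come from the commit above).
    from matcha.data.text_mel_datamodule import TextMelDataModule

    params = get_params()
    data_module = TextMelDataModule(hparams=params.data)
    data_module.setup()  # assumed: builds the train/valid datasets first
    for b in data_module.train_dataloader():
        print(b["x"].shape, b["x"].dtype)  # [batch_size, max_text_len], torch.int64
        print(b["y"].shape, b["y"].dtype)  # [batch_size, 80, num_frames], torch.float32
        print(b["x_lengths"], b["y_lengths"])  # per-utterance lengths, torch.int64
        break  # one batch is enough for a sanity check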


@@ -6,19 +6,17 @@ from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple

-import gdown
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import wget
-from omegaconf import DictConfig
+# from omegaconf import DictConfig

-from matcha.utils import pylogger, rich_utils
+# from matcha.utils import pylogger, rich_utils

-log = pylogger.get_pylogger(__name__)
+# log = pylogger.get_pylogger(__name__)


-def extras(cfg: DictConfig) -> None:
+def extras(cfg: 'DictConfig') -> None:
     """Applies optional utilities before the task is started.

     Utilities:
@@ -207,6 +205,8 @@ def get_user_data_dir(appname="matcha_tts"):
 def assert_model_downloaded(checkpoint_path, url, use_wget=True):
+    import gdown
+    import wget
+
     if Path(checkpoint_path).exists():
-        log.debug(f"[+] Model already present at {checkpoint_path}!")
+        print(f"[+] Model already present at {checkpoint_path}!")
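
Moving the gdown and wget imports inside assert_model_downloaded() applies the lazy-import pattern: the module can now be imported even when those optional download dependencies are missing, and switching log.debug to print removes the last use of the commented-out pylogger. A generic sketch of the pattern, with a hypothetical helper name that is not taken from this commit:

    def download_checkpoint(url: str, out: str, use_wget: bool = True) -> None:
        # Import the optional dependency only on the code path that uses it,
        # so importing the enclosing module never fails when it is absent.
        if use_wget:
            import wget

            wget.download(url=url, out=out)
        else:
            import gdown

            gdown.download(url=url, output=out, fuzzy=True)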