mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add dataset
This commit is contained in:
parent
6fac3a3143
commit
ccd2dcc9f9
@ -7,6 +7,119 @@ import torch
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
from matcha.models.matcha_tts import MatchaTTS
|
||||
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||
|
||||
|
||||
def _get_data_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"name": "ljspeech",
|
||||
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||
"batch_size": 32,
|
||||
"num_workers": 3,
|
||||
"pin_memory": False,
|
||||
"cleaners": ["english_cleaners2"],
|
||||
"add_blank": True,
|
||||
"n_spks": 1,
|
||||
"n_fft": 1024,
|
||||
"n_feats": 80,
|
||||
"sample_rate": 22050,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"f_min": 0,
|
||||
"f_max": 8000,
|
||||
"seed": 1234,
|
||||
"load_durations": False,
|
||||
"data_statistics": AttributeDict(
|
||||
{
|
||||
"mel_mean": -5.517028331756592,
|
||||
"mel_std": 2.0643954277038574,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def _get_model_params() -> AttributeDict:
|
||||
n_feats = 80
|
||||
filter_channels_dp = 256
|
||||
encoder_params_p_dropout = 0.1
|
||||
params = AttributeDict(
|
||||
{
|
||||
"n_vocab": 178,
|
||||
"n_spks": 1, # for ljspeech.
|
||||
"spk_emb_dim": 64,
|
||||
"n_feats": n_feats,
|
||||
"out_size": None, # or use 172
|
||||
"prior_loss": True,
|
||||
"use_precomputed_durations": False,
|
||||
"encoder": AttributeDict(
|
||||
{
|
||||
"encoder_type": "RoPE Encoder", # not used
|
||||
"encoder_params": AttributeDict(
|
||||
{
|
||||
"n_feats": n_feats,
|
||||
"n_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
"spk_emb_dim": 64,
|
||||
"n_spks": 1,
|
||||
"prenet": True,
|
||||
}
|
||||
),
|
||||
"duration_predictor_params": AttributeDict(
|
||||
{
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
"decoder": AttributeDict(
|
||||
{
|
||||
"channels": [256, 256],
|
||||
"dropout": 0.05,
|
||||
"attention_head_dim": 64,
|
||||
"n_blocks": 1,
|
||||
"num_mid_blocks": 2,
|
||||
"num_heads": 2,
|
||||
"act_fn": "snakebeta",
|
||||
}
|
||||
),
|
||||
"cfm": AttributeDict(
|
||||
{
|
||||
"name": "CFM",
|
||||
"solver": "euler",
|
||||
"sigma_min": 1e-4,
|
||||
}
|
||||
),
|
||||
"optimizer": AttributeDict(
|
||||
{
|
||||
"lr": 1e-4,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def get_params():
|
||||
params = AttributeDict(
|
||||
{
|
||||
"model": _get_model_params(),
|
||||
"data": _get_data_params(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def get_model(params):
|
||||
@ -15,75 +128,23 @@ def get_model(params):
|
||||
|
||||
|
||||
def main():
|
||||
n_feats = 80
|
||||
filter_channels_dp = 256
|
||||
encoder_params_p_dropout = 0.1
|
||||
params = AttributeDict(
|
||||
{
|
||||
"model": AttributeDict(
|
||||
{
|
||||
"n_vocab": 178,
|
||||
"n_spks": 1, # for ljspeech.
|
||||
"spk_emb_dim": 64,
|
||||
"n_feats": n_feats,
|
||||
"out_size": None, # or use 172
|
||||
"prior_loss": True,
|
||||
"use_precomputed_durations": False,
|
||||
"encoder": AttributeDict(
|
||||
{
|
||||
"encoder_type": "RoPE Encoder", # not used
|
||||
"encoder_params": AttributeDict(
|
||||
{
|
||||
"n_feats": n_feats,
|
||||
"n_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
"spk_emb_dim": 64,
|
||||
"n_spks": 1,
|
||||
"prenet": True,
|
||||
}
|
||||
),
|
||||
"duration_predictor_params": AttributeDict(
|
||||
{
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
"decoder": AttributeDict(
|
||||
{
|
||||
"channels": [256, 256],
|
||||
"dropout": 0.05,
|
||||
"attention_head_dim": 64,
|
||||
"n_blocks": 1,
|
||||
"num_mid_blocks": 2,
|
||||
"num_heads": 2,
|
||||
"act_fn": "snakebeta",
|
||||
}
|
||||
),
|
||||
"cfm": AttributeDict(
|
||||
{
|
||||
"name": "CFM",
|
||||
"solver": "euler",
|
||||
"sigma_min": 1e-4,
|
||||
}
|
||||
),
|
||||
"optimizer": AttributeDict(
|
||||
{
|
||||
"lr": 1e-4,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
params = get_params()
|
||||
|
||||
data_module = TextMelDataModule(hparams=params.data)
|
||||
if False:
|
||||
for b in data_module.train_dataloader():
|
||||
assert isinstance(b, dict)
|
||||
# b.keys()
|
||||
# ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
|
||||
# x: [batch_size, 289], torch.int64
|
||||
# x_lengths: [batch_size], torch.int64
|
||||
# y: [batch_size, n_feats, num_frames], torch.float32
|
||||
# y_lengths: [batch_size], torch.int64
|
||||
# spks: None
|
||||
# filepaths: list, (batch_size,)
|
||||
# x_texts: list, (batch_size,)
|
||||
# durations: None
|
||||
|
||||
m = get_model(params)
|
||||
print(m)
|
||||
|
||||
|
@ -6,19 +6,17 @@ from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import gdown
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import wget
|
||||
from omegaconf import DictConfig
|
||||
# from omegaconf import DictConfig
|
||||
|
||||
from matcha.utils import pylogger, rich_utils
|
||||
# from matcha.utils import pylogger, rich_utils
|
||||
|
||||
log = pylogger.get_pylogger(__name__)
|
||||
# log = pylogger.get_pylogger(__name__)
|
||||
|
||||
|
||||
def extras(cfg: DictConfig) -> None:
|
||||
def extras(cfg: 'DictConfig') -> None:
|
||||
"""Applies optional utilities before the task is started.
|
||||
|
||||
Utilities:
|
||||
@ -207,6 +205,8 @@ def get_user_data_dir(appname="matcha_tts"):
|
||||
|
||||
|
||||
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
|
||||
import gdown
|
||||
import wget
|
||||
if Path(checkpoint_path).exists():
|
||||
log.debug(f"[+] Model already present at {checkpoint_path}!")
|
||||
print(f"[+] Model already present at {checkpoint_path}!")
|
||||
|
Loading…
x
Reference in New Issue
Block a user