mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
add dataset
This commit is contained in:
parent
6fac3a3143
commit
ccd2dcc9f9
@ -7,20 +7,46 @@ import torch
|
|||||||
|
|
||||||
from icefall.utils import AttributeDict
|
from icefall.utils import AttributeDict
|
||||||
from matcha.models.matcha_tts import MatchaTTS
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
|
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||||
|
|
||||||
|
|
||||||
def get_model(params):
|
def _get_data_params() -> AttributeDict:
|
||||||
m = MatchaTTS(**params.model)
|
params = AttributeDict(
|
||||||
return m
|
{
|
||||||
|
"name": "ljspeech",
|
||||||
|
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||||
|
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||||
|
"batch_size": 32,
|
||||||
|
"num_workers": 3,
|
||||||
|
"pin_memory": False,
|
||||||
|
"cleaners": ["english_cleaners2"],
|
||||||
|
"add_blank": True,
|
||||||
|
"n_spks": 1,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"n_feats": 80,
|
||||||
|
"sample_rate": 22050,
|
||||||
|
"hop_length": 256,
|
||||||
|
"win_length": 1024,
|
||||||
|
"f_min": 0,
|
||||||
|
"f_max": 8000,
|
||||||
|
"seed": 1234,
|
||||||
|
"load_durations": False,
|
||||||
|
"data_statistics": AttributeDict(
|
||||||
|
{
|
||||||
|
"mel_mean": -5.517028331756592,
|
||||||
|
"mel_std": 2.0643954277038574,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def _get_model_params() -> AttributeDict:
|
||||||
n_feats = 80
|
n_feats = 80
|
||||||
filter_channels_dp = 256
|
filter_channels_dp = 256
|
||||||
encoder_params_p_dropout = 0.1
|
encoder_params_p_dropout = 0.1
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
|
||||||
"model": AttributeDict(
|
|
||||||
{
|
{
|
||||||
"n_vocab": 178,
|
"n_vocab": 178,
|
||||||
"n_spks": 1, # for ljspeech.
|
"n_spks": 1, # for ljspeech.
|
||||||
@ -82,8 +108,43 @@ def main():
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_params():
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"model": _get_model_params(),
|
||||||
|
"data": _get_data_params(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(params):
|
||||||
|
m = MatchaTTS(**params.model)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
params = get_params()
|
||||||
|
|
||||||
|
data_module = TextMelDataModule(hparams=params.data)
|
||||||
|
if False:
|
||||||
|
for b in data_module.train_dataloader():
|
||||||
|
assert isinstance(b, dict)
|
||||||
|
# b.keys()
|
||||||
|
# ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
|
||||||
|
# x: [batch_size, 289], torch.int64
|
||||||
|
# x_lengths: [batch_size], torch.int64
|
||||||
|
# y: [batch_size, n_feats, num_frames], torch.float32
|
||||||
|
# y_lengths: [batch_size], torch.int64
|
||||||
|
# spks: None
|
||||||
|
# filepaths: list, (batch_size,)
|
||||||
|
# x_texts: list, (batch_size,)
|
||||||
|
# durations: None
|
||||||
|
|
||||||
m = get_model(params)
|
m = get_model(params)
|
||||||
print(m)
|
print(m)
|
||||||
|
|
||||||
|
@ -6,19 +6,17 @@ from math import ceil
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Tuple
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
import gdown
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import wget
|
# from omegaconf import DictConfig
|
||||||
from omegaconf import DictConfig
|
|
||||||
|
|
||||||
from matcha.utils import pylogger, rich_utils
|
# from matcha.utils import pylogger, rich_utils
|
||||||
|
|
||||||
log = pylogger.get_pylogger(__name__)
|
# log = pylogger.get_pylogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def extras(cfg: DictConfig) -> None:
|
def extras(cfg: 'DictConfig') -> None:
|
||||||
"""Applies optional utilities before the task is started.
|
"""Applies optional utilities before the task is started.
|
||||||
|
|
||||||
Utilities:
|
Utilities:
|
||||||
@ -207,6 +205,8 @@ def get_user_data_dir(appname="matcha_tts"):
|
|||||||
|
|
||||||
|
|
||||||
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
|
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
|
||||||
|
import gdown
|
||||||
|
import wget
|
||||||
if Path(checkpoint_path).exists():
|
if Path(checkpoint_path).exists():
|
||||||
log.debug(f"[+] Model already present at {checkpoint_path}!")
|
log.debug(f"[+] Model already present at {checkpoint_path}!")
|
||||||
print(f"[+] Model already present at {checkpoint_path}!")
|
print(f"[+] Model already present at {checkpoint_path}!")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user