Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

commit 748557feba (parent: 6a4cb112dd)

    add onnx export
egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
+../local/compute_fbank_ljspeech.py
egs/ljspeech/TTS/matcha/export_onnx.py (new executable file, 119 lines)
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+
+import json
+import logging
+
+import torch
+from inference import get_parser
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from icefall.checkpoint import load_checkpoint
+from onnxruntime.quantization import QuantType, quantize_dynamic
+
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lengths: torch.Tensor,
+        temperature: torch.Tensor,
+        length_scale: torch.Tensor,
+    ) -> torch.Tensor:
"""
|
||||||
|
Args: :
|
||||||
|
x: (batch_size, num_tokens), torch.int64
|
||||||
|
x_lengths: (batch_size,), torch.int64
|
||||||
|
temperature: (1,), torch.float32
|
||||||
|
length_scale (1,), torch.float32
|
||||||
|
Returns:
|
||||||
|
mel: (batch_size, feat_dim, num_frames)
|
||||||
|
|
||||||
|
"""
|
||||||
|
+        mel = self.model.synthesise(
+            x=x,
+            x_lengths=x_lengths,
+            n_timesteps=3,
+            temperature=temperature,
+            length_scale=length_scale,
+        )["mel"]
+
+        # mel: (batch_size, feat_dim, num_frames)
+
+        return mel
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    params = get_params()
+
+    params.update(vars(args))
+
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+    wrapper = ModelWrapper(model)
+    wrapper.eval()
+
+    # Use a large value so that the rotary position embedding in the text
+    # encoder has a large initial length
+    x = torch.ones(1, 2000, dtype=torch.int64)
+    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+    temperature = torch.tensor([1.0])
+    length_scale = torch.tensor([1.0])
+    mel = wrapper(x, x_lengths, temperature, length_scale)
+    print("mel", mel.shape)
+
+    opset_version = 14
+    filename = "model.onnx"
+    torch.onnx.export(
+        wrapper,
+        (x, x_lengths, temperature, length_scale),
+        filename,
+        opset_version=opset_version,
+        input_names=["x", "x_length", "temperature", "length_scale"],
+        output_names=["mel"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},
+            "x_length": {0: "N"},
+            "mel": {0: "N", 2: "L"},
+        },
+    )
+
+    print("Generate int8 quantization models")
+
+    filename_int8 = "model.int8.onnx"
+    quantize_dynamic(
+        model_input=filename,
+        model_output=filename_int8,
+        weight_type=QuantType.QInt8,
+    )
+
+    print(f"Saved to {filename} and {filename_int8}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
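Note that export_onnx.py bakes n_timesteps=3 into the exported graph, so the ODE step count is fixed at export time. A quick, hypothetical way to sanity-check the result is to load model.onnx with onnxruntime and list its declared inputs and outputs (this sketch assumes the export above has already run in the current directory):

import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
# The exporter declares x, x_length, temperature and length_scale as inputs,
# with dynamic batch ("N") and token-length ("L") axes.
for i in session.get_inputs():
    print(i.name, i.shape, i.type)
# ... and a single output "mel" of shape (N, feat_dim, L).
for o in session.get_outputs():
    print(o.name, o.shape, o.type)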
egs/ljspeech/TTS/matcha/inference.py (modified)
@@ -5,6 +5,7 @@ import datetime as dt
 import logging
 from pathlib import Path
 
+import json
 import numpy as np
 import soundfile as sf
 import torch
@@ -29,7 +30,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=1320,
+        default=2810,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,
@@ -38,7 +39,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-fbank",
+        default="matcha/exp-new-3",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -51,6 +52,13 @@ def get_parser():
         default="data/tokens.txt",
     )
 
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to the cmvn.json file.""",
+    )
+
     return parser
@@ -111,13 +119,21 @@ def main():
     params = get_params()
 
     params.update(vars(args))
-    logging.info(params)
 
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
 
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    logging.info(params)
+
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -127,9 +143,9 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")
 
     texts = [
-        "How are you doing, my friend",
-        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing? my friend.",
+        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]
 
     # Number of ODE Solver steps
@@ -174,7 +190,7 @@ def main():
         rtfs_w.append(rtf_w)
 
         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output-1320")
+        save_to_folder(i, output, folder=f"./my-output-{params.epoch}")
 
     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
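The new --cmvn option points at the JSON written during fbank computation; the two lookups above imply it holds global mean/std statistics under the keys fbank_mean and fbank_std. A minimal sketch of the assumed file shape (the numbers are placeholders, not real LJSpeech statistics):

import json

# data/fbank/cmvn.json is assumed to look like:
#   {"fbank_mean": -1.17, "fbank_std": 1.93}
with open("data/fbank/cmvn.json") as f:
    stats = json.load(f)
print(stats["fbank_mean"], stats["fbank_std"])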
egs/ljspeech/TTS/matcha/onnx_pretrained.py (new executable file, 84 lines)
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+import logging
+
+import onnxruntime as ort
+import torch
+from tokenizer import Tokenizer
+
+from inference import load_vocoder
+import soundfile as sf
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.tokenizer = Tokenizer("./data/tokens.txt")
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 2, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        print("x_lengths", x_lengths)
+        print("x", x.shape)
+
+        temperature = torch.tensor([1.0], dtype=torch.float32)
+        length_scale = torch.tensor([1.0], dtype=torch.float32)
+
+        mel = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: x_lengths.numpy(),
+                self.model.get_inputs()[2].name: temperature.numpy(),
+                self.model.get_inputs()[3].name: length_scale.numpy(),
+            },
+        )[0]
+
+        return torch.from_numpy(mel)
+
+
+@torch.inference_mode()
+def main():
+    model = OnnxModel("./model.onnx")
+    text = "hello, how are you doing?"
+    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
+    x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.int64)
+    mel = model(x)
+    print("mel", mel.shape)  # (1, 80, 170)
+
+    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
+    audio = vocoder(mel).clamp(-1, 1)
+    print("audio", audio.shape)  # (1, 1, num_samples)
+    audio = audio.squeeze()
+
+    # skip denoiser
+    sf.write("onnx.wav", audio, 22050, "PCM_16")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
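The same OnnxModel driver should also accept the int8 model produced by the export step; a hypothetical side-by-side check (assumes both model.onnx and model.int8.onnx exist next to onnx_pretrained.py):

import torch
from onnx_pretrained import OnnxModel

fp32 = OnnxModel("./model.onnx")
int8 = OnnxModel("./model.int8.onnx")

x = fp32.tokenizer.texts_to_token_ids(["hello world"], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.int64)

# Quantization can perturb the predicted durations, so the number of
# output frames may differ slightly between the two models.
print(fp32(x).shape, int8(x).shape)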
egs/ljspeech/TTS/matcha/train.py (modified)
@@ -7,6 +7,7 @@ import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
+import json
 
 import k2
 import torch
@@ -90,6 +91,13 @@ def get_parser():
         help="""Path to vocabulary.""",
     )
 
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to the cmvn.json file.""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -123,11 +131,8 @@ def get_parser():
 def get_data_statistics():
     return AttributeDict(
         {
-            # "mel_mean": -5.517028331756592,  # matcha-tts
-            # "mel_std": 2.0643954277038574,
-            # ours
-            "mel_mean": -1.168782114982605,
-            "mel_std": 1.9283572435379028,
+            "mel_mean": 0,
+            "mel_std": 1,
         }
     )
 
@@ -138,9 +143,9 @@ def _get_data_params() -> AttributeDict:
         "name": "ljspeech",
         "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
         "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-        "batch_size": 64,
-        "num_workers": 1,
-        "pin_memory": False,
+        # "batch_size": 64,
+        # "num_workers": 1,
+        # "pin_memory": False,
         "cleaners": ["english_cleaners2"],
         "add_blank": True,
         "n_spks": 1,
@@ -312,7 +317,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
     tokens = batch["tokens"]
 
     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=False, add_eos=False
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
     )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
@@ -619,10 +624,17 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.pad_id
+    params.pad_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
-    params.model_args.n_vocab = 178
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
     logging.info(params)
     print(params)
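With this change, train.py no longer hard-codes LJSpeech mel statistics; they come from cmvn.json instead, and get_data_statistics() falls back to an identity normalization (mean 0, std 1). The commit does not show how cmvn.json is produced, but conceptually the statistics are a global mean/std over all training fbank frames; a rough, hypothetical lhotse-based sketch (manifest path assumed):

import json

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/ljspeech_cuts_train.jsonl.gz")
n, total, total_sq = 0, 0.0, 0.0
for cut in cuts:
    feats = cut.load_features()  # (num_frames, 80) numpy array
    n += feats.size
    total += float(feats.sum())
    total_sq += float((feats**2).sum())
mean = total / n
std = (total_sq / n - mean**2) ** 0.5
with open("data/fbank/cmvn.json", "w") as f:
    json.dump({"fbank_mean": mean, "fbank_std": std}, f)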
egs/ljspeech/TTS/matcha/tts_datamodule.py (modified)
@@ -24,7 +24,8 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse import CutSet, load_manifest_lazy
+from compute_fbank_ljspeech import MyFbank, MyFbankConfig
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -176,22 +177,19 @@ class LJSpeechTtsDataModule:
 
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             train = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
 
@@ -229,7 +227,8 @@ class LJSpeechTtsDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
             worker_init_fn=worker_init_fn,
         )
 
@@ -239,22 +238,19 @@ class LJSpeechTtsDataModule:
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             validate = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -276,7 +272,8 @@ class LJSpeechTtsDataModule:
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
         )
 
         return valid_dl
@@ -285,22 +282,19 @@ class LJSpeechTtsDataModule:
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             test = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
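All three dataloaders now build the same MyFbankConfig (n_fft=1024, hop_length=256, win_length=1024, 80 mels over 0-8000 Hz at 22050 Hz). MyFbank itself is defined in the symlinked compute_fbank_ljspeech.py; as a rough approximation of what those parameters mean, a torchaudio mel spectrogram with the same settings would be:

import torch
import torchaudio

# Approximate stand-in for MyFbank(MyFbankConfig(...)); the real extractor
# lives in local/compute_fbank_ljspeech.py and may differ in detail.
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    f_min=0,
    f_max=8000,
    n_mels=80,
)
wave = torch.randn(1, 22050)  # one second of dummy audio
print(mel(wave).shape)  # (1, 80, num_frames)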