mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add onnx export
This commit is contained in:
parent
6a4cb112dd
commit
748557feba
1
egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py
Symbolic link
1
egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py
Symbolic link
@ -0,0 +1 @@
|
||||
../local/compute_fbank_ljspeech.py
|
119
egs/ljspeech/TTS/matcha/export_onnx.py
Executable file
119
egs/ljspeech/TTS/matcha/export_onnx.py
Executable file
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from inference import get_parser
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
|
||||
class ModelWrapper(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lengths: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
length_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args: :
|
||||
x: (batch_size, num_tokens), torch.int64
|
||||
x_lengths: (batch_size,), torch.int64
|
||||
temperature: (1,), torch.float32
|
||||
length_scale (1,), torch.float32
|
||||
Returns:
|
||||
mel: (batch_size, feat_dim, num_frames)
|
||||
|
||||
"""
|
||||
mel = self.model.synthesise(
|
||||
x=x,
|
||||
x_lengths=x_lengths,
|
||||
n_timesteps=3,
|
||||
temperature=temperature,
|
||||
length_scale=length_scale,
|
||||
)["mel"]
|
||||
|
||||
# mel: (batch_size, feat_dim, num_frames)
|
||||
|
||||
return mel
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
|
||||
wrapper = ModelWrapper(model)
|
||||
wrapper.eval()
|
||||
|
||||
# Use a large value so the the rotary position embedding in the text
|
||||
# encoder has a large initial length
|
||||
x = torch.ones(1, 2000, dtype=torch.int64)
|
||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
temperature = torch.tensor([1.0])
|
||||
length_scale = torch.tensor([1.0])
|
||||
mel = wrapper(x, x_lengths, temperature, length_scale)
|
||||
print("mel", mel.shape)
|
||||
|
||||
opset_version = 14
|
||||
filename = "model.onnx"
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
(x, x_lengths, temperature, length_scale),
|
||||
filename,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_length", "temperature", "length_scale"],
|
||||
output_names=["mel"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "L"},
|
||||
"x_length": {0: "N"},
|
||||
"mel": {0: "N", 2: "L"},
|
||||
},
|
||||
)
|
||||
|
||||
print("Generate int8 quantization models")
|
||||
|
||||
filename_int8 = "model.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=filename,
|
||||
model_output=filename_int8,
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
print(f"Saved to {filename} and {filename_int8}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -5,6 +5,7 @@ import datetime as dt
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
@ -29,7 +30,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=1320,
|
||||
default=2810,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
@ -38,7 +39,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp-fbank",
|
||||
default="matcha/exp-new-3",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -51,6 +52,13 @@ def get_parser():
|
||||
default="data/tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -111,13 +119,21 @@ def main():
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
logging.info(params)
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
@ -127,9 +143,9 @@ def main():
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
|
||||
texts = [
|
||||
"How are you doing, my friend",
|
||||
# "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
|
||||
# "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
|
||||
"How are you doing? my friend.",
|
||||
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
|
||||
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
|
||||
]
|
||||
|
||||
# Number of ODE Solver steps
|
||||
@ -174,7 +190,7 @@ def main():
|
||||
rtfs_w.append(rtf_w)
|
||||
|
||||
# Save the generated waveform
|
||||
save_to_folder(i, output, folder="./my-output-1320")
|
||||
save_to_folder(i, output, folder=f"./my-output-{params.epoch}")
|
||||
|
||||
print(f"Number of ODE steps: {n_timesteps}")
|
||||
print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
|
||||
|
84
egs/ljspeech/TTS/matcha/onnx_pretrained.py
Executable file
84
egs/ljspeech/TTS/matcha/onnx_pretrained.py
Executable file
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from inference import load_vocoder
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
self.tokenizer = Tokenizer("./data/tokens.txt")
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
for i in self.model.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("-----")
|
||||
|
||||
for i in self.model.get_outputs():
|
||||
print(i)
|
||||
|
||||
def __call__(self, x: torch.tensor):
|
||||
assert x.ndim == 2, x.shape
|
||||
assert x.shape[0] == 1, x.shape
|
||||
|
||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
print("x_lengths", x_lengths)
|
||||
print("x", x.shape)
|
||||
|
||||
temperature = torch.tensor([1.0], dtype=torch.float32)
|
||||
length_scale = torch.tensor([1.0], dtype=torch.float32)
|
||||
|
||||
mel = self.model.run(
|
||||
[self.model.get_outputs()[0].name],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lengths.numpy(),
|
||||
self.model.get_inputs()[2].name: temperature.numpy(),
|
||||
self.model.get_inputs()[3].name: length_scale.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(mel)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
model = OnnxModel("./model.onnx")
|
||||
text = "hello, how are you doing?"
|
||||
text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
|
||||
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
||||
x = torch.tensor(x, dtype=torch.int64)
|
||||
mel = model(x)
|
||||
print("mel", mel.shape) # (1, 80, 170)
|
||||
|
||||
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
|
||||
audio = vocoder(mel).clamp(-1, 1)
|
||||
print("audio", audio.shape) # (1, 1, num_samples)
|
||||
audio = audio.squeeze()
|
||||
|
||||
# skip denoiser
|
||||
sf.write("onnx.wav", audio, 22050, "PCM_16")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -7,6 +7,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import json
|
||||
|
||||
import k2
|
||||
import torch
|
||||
@ -90,6 +91,13 @@ def get_parser():
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@ -123,11 +131,8 @@ def get_parser():
|
||||
def get_data_statistics():
|
||||
return AttributeDict(
|
||||
{
|
||||
# "mel_mean": -5.517028331756592, # matcha-tts
|
||||
# "mel_std": 2.0643954277038574,
|
||||
# ours
|
||||
"mel_mean": -1.168782114982605,
|
||||
"mel_std": 1.9283572435379028,
|
||||
"mel_mean": 0,
|
||||
"mel_std": 1,
|
||||
}
|
||||
)
|
||||
|
||||
@ -138,9 +143,9 @@ def _get_data_params() -> AttributeDict:
|
||||
"name": "ljspeech",
|
||||
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||
"batch_size": 64,
|
||||
"num_workers": 1,
|
||||
"pin_memory": False,
|
||||
# "batch_size": 64,
|
||||
# "num_workers": 1,
|
||||
# "pin_memory": False,
|
||||
"cleaners": ["english_cleaners2"],
|
||||
"add_blank": True,
|
||||
"n_spks": 1,
|
||||
@ -312,7 +317,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param
|
||||
tokens = batch["tokens"]
|
||||
|
||||
tokens = tokenizer.tokens_to_token_ids(
|
||||
tokens, intersperse_blank=True, add_sos=False, add_eos=False
|
||||
tokens, intersperse_blank=True, add_sos=True, add_eos=True
|
||||
)
|
||||
tokens = k2.RaggedTensor(tokens)
|
||||
row_splits = tokens.shape.row_splits(1)
|
||||
@ -619,10 +624,17 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.pad_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
params.model_args.n_vocab = 178
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
logging.info(params)
|
||||
print(params)
|
||||
|
@ -24,7 +24,8 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from compute_fbank_ljspeech import MyFbank, MyFbankConfig
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
@ -176,22 +177,19 @@ class LJSpeechTtsDataModule:
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = FbankConfig(
|
||||
config = MyFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=1024 / sampling_rate, # (in second),
|
||||
frame_shift=256 / sampling_rate, # (in second)
|
||||
use_fft_mag=True,
|
||||
remove_dc_offset=False,
|
||||
preemph_coeff=0,
|
||||
low_freq=0,
|
||||
high_freq=8000,
|
||||
# should be identical to n_feats in ./train.py
|
||||
num_filters=80,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
@ -229,7 +227,8 @@ class LJSpeechTtsDataModule:
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
@ -239,22 +238,19 @@ class LJSpeechTtsDataModule:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = FbankConfig(
|
||||
config = MyFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=1024 / sampling_rate, # (in second),
|
||||
frame_shift=256 / sampling_rate, # (in second)
|
||||
use_fft_mag=True,
|
||||
remove_dc_offset=False,
|
||||
preemph_coeff=0,
|
||||
low_freq=0,
|
||||
high_freq=8000,
|
||||
# should be identical to n_feats in ./train.py
|
||||
num_filters=80,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
@ -276,7 +272,8 @@ class LJSpeechTtsDataModule:
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
@ -285,22 +282,19 @@ class LJSpeechTtsDataModule:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = FbankConfig(
|
||||
config = MyFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=1024 / sampling_rate, # (in second),
|
||||
frame_shift=256 / sampling_rate, # (in second)
|
||||
use_fft_mag=True,
|
||||
remove_dc_offset=False,
|
||||
preemph_coeff=0,
|
||||
low_freq=0,
|
||||
high_freq=8000,
|
||||
# should be identical to n_feats in ./train.py
|
||||
num_filters=80,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user