add onnx export

Fangjun Kuang 2024-10-21 21:24:29 +08:00
parent 6a4cb112dd
commit 748557feba
6 changed files with 280 additions and 54 deletions

View File

@@ -0,0 +1 @@
../local/compute_fbank_ljspeech.py

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python3

import json
import logging

import torch
from inference import get_parser
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from onnxruntime.quantization import QuantType, quantize_dynamic


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
        temperature: torch.Tensor,
        length_scale: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x: (batch_size, num_tokens), torch.int64
          x_lengths: (batch_size,), torch.int64
          temperature: (1,), torch.float32
          length_scale: (1,), torch.float32
        Returns:
          mel: (batch_size, feat_dim, num_frames)
        """
        mel = self.model.synthesise(
            x=x,
            x_lengths=x_lengths,
            n_timesteps=3,
            temperature=temperature,
            length_scale=length_scale,
        )["mel"]
        # mel: (batch_size, feat_dim, num_frames)
        return mel
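
# Note: n_timesteps (the number of ODE solver steps) is hard-coded to 3 in
# ModelWrapper.forward above, so it is frozen into the exported graph;
# exposing it at runtime would require an additional model input.
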
@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]
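
    # The cmvn.json file is assumed to be a JSON object of the form
    # {"fbank_mean": <float>, "fbank_std": <float>}; the key names are
    # inferred from the reads above.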

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    wrapper = ModelWrapper(model)
    wrapper.eval()

    # Use a large value so that the rotary position embedding in the text
    # encoder has a large initial length
    x = torch.ones(1, 2000, dtype=torch.int64)
    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
    temperature = torch.tensor([1.0])
    length_scale = torch.tensor([1.0])
    mel = wrapper(x, x_lengths, temperature, length_scale)
    print("mel", mel.shape)

    opset_version = 14
    filename = "model.onnx"
    torch.onnx.export(
        wrapper,
        (x, x_lengths, temperature, length_scale),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_length", "temperature", "length_scale"],
        output_names=["mel"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "x_length": {0: "N"},
            "mel": {0: "N", 2: "L"},
        },
    )
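
    # Optional sanity check (an illustrative addition, not part of the
    # original export flow): validate the exported graph. Assumes the
    # `onnx` package is installed.
    import onnx

    onnx.checker.check_model(onnx.load(filename))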
print("Generate int8 quantization models")
filename_int8 = "model.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QInt8,
)
print(f"Saved to {filename} and {filename_int8}")

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

View File

@@ -5,6 +5,7 @@ import datetime as dt
 import logging
 from pathlib import Path
+import json
 import numpy as np
 import soundfile as sf
 import torch
@@ -29,7 +30,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=1320,
+        default=2810,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,
@@ -38,7 +39,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-fbank",
+        default="matcha/exp-new-3",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -51,6 +52,13 @@
         default="data/tokens.txt",
     )
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
     return parser
@@ -111,13 +119,21 @@ def main():
     params = get_params()
     params.update(vars(args))
-    logging.info(params)
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -127,9 +143,9 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")
     texts = [
-        "How are you doing, my friend",
-        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing? my friend.",
+        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]
     # Number of ODE Solver steps
@@ -174,7 +190,7 @@ def main():
         rtfs_w.append(rtf_w)
         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output-1320")
+        save_to_folder(i, output, folder=f"./my-output-{params.epoch}")
     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3

import logging

import onnxruntime as ort
import torch
from tokenizer import Tokenizer
from inference import load_vocoder
import soundfile as sf


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.tokenizer = Tokenizer("./data/tokens.txt")

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        for i in self.model.get_inputs():
            print(i)
        print("-----")
        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x: torch.Tensor):
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        print("x_lengths", x_lengths)
        print("x", x.shape)

        temperature = torch.tensor([1.0], dtype=torch.float32)
        length_scale = torch.tensor([1.0], dtype=torch.float32)
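
        # In Matcha-TTS, temperature scales the noise that initializes the
        # ODE solver, while length_scale stretches the predicted durations
        # (values > 1 give slower speech); both are inputs of the exported
        # graph, so they can be changed per call.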

        mel = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lengths.numpy(),
                self.model.get_inputs()[2].name: temperature.numpy(),
                self.model.get_inputs()[3].name: length_scale.numpy(),
            },
        )[0]

        return torch.from_numpy(mel)


@torch.inference_mode()
def main():
    model = OnnxModel("./model.onnx")

    text = "hello, how are you doing?"
    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."

    x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    x = torch.tensor(x, dtype=torch.int64)

    mel = model(x)
    print("mel", mel.shape)  # (1, 80, 170)

    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
    audio = vocoder(mel).clamp(-1, 1)
    print("audio", audio.shape)  # (1, 1, num_samples)
    audio = audio.squeeze()

    # skip denoiser
    sf.write("onnx.wav", audio, 22050, "PCM_16")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

View File

@@ -7,6 +7,7 @@ import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
+import json
 import k2
 import torch
@@ -90,6 +91,13 @@ def get_parser():
         help="""Path to vocabulary.""",
     )
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
     parser.add_argument(
         "--seed",
         type=int,
@@ -123,11 +131,8 @@
 def get_data_statistics():
     return AttributeDict(
         {
-            # "mel_mean": -5.517028331756592, # matcha-tts
-            # "mel_std": 2.0643954277038574,
-            # ours
-            "mel_mean": -1.168782114982605,
-            "mel_std": 1.9283572435379028,
+            "mel_mean": 0,
+            "mel_std": 1,
         }
     )
@@ -138,9 +143,9 @@ def _get_data_params() -> AttributeDict:
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "batch_size": 64,
-            "num_workers": 1,
-            "pin_memory": False,
+            # "batch_size": 64,
+            # "num_workers": 1,
+            # "pin_memory": False,
             "cleaners": ["english_cleaners2"],
             "add_blank": True,
             "n_spks": 1,
@@ -312,7 +317,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param
     tokens = batch["tokens"]
     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=False, add_eos=False
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
     )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
@@ -619,10 +624,17 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
+    params.pad_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
-    params.model_args.n_vocab = params.vocab_size
+    params.model_args.n_vocab = 178
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
     logging.info(params)
+    print(params)

View File

@@ -24,7 +24,8 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse import CutSet, load_manifest_lazy
+from compute_fbank_ljspeech import MyFbank, MyFbankConfig
 from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -176,22 +177,19 @@ class LJSpeechTtsDataModule:
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate, # (in second),
-                frame_shift=256 / sampling_rate, # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             train = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
@@ -229,7 +227,8 @@
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
             worker_init_fn=worker_init_fn,
         )
@@ -239,22 +238,19 @@
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate, # (in second),
-                frame_shift=256 / sampling_rate, # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             validate = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -276,7 +272,8 @@
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
         )
         return valid_dl
@@ -285,22 +282,19 @@
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate, # (in second),
-                frame_shift=256 / sampling_rate, # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             test = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else: