add onnx export

Fangjun Kuang 2024-10-21 21:24:29 +08:00
parent 6a4cb112dd
commit 748557feba
6 changed files with 280 additions and 54 deletions

View File

@@ -0,0 +1 @@
../local/compute_fbank_ljspeech.py

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python3

import json
import logging

import torch
from inference import get_parser
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from onnxruntime.quantization import QuantType, quantize_dynamic


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
        temperature: torch.Tensor,
        length_scale: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x: (batch_size, num_tokens), torch.int64
          x_lengths: (batch_size,), torch.int64
          temperature: (1,), torch.float32
          length_scale: (1,), torch.float32
        Returns:
          mel: (batch_size, feat_dim, num_frames)
        """
        mel = self.model.synthesise(
            x=x,
            x_lengths=x_lengths,
            n_timesteps=3,
            temperature=temperature,
            length_scale=length_scale,
        )["mel"]
        # mel: (batch_size, feat_dim, num_frames)
        return mel


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    wrapper = ModelWrapper(model)
    wrapper.eval()

    # Use a large value so that the rotary position embedding in the text
    # encoder is initialized with a large length.
    x = torch.ones(1, 2000, dtype=torch.int64)
    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
    temperature = torch.tensor([1.0])
    length_scale = torch.tensor([1.0])
    mel = wrapper(x, x_lengths, temperature, length_scale)
    print("mel", mel.shape)

    opset_version = 14
    filename = "model.onnx"
    torch.onnx.export(
        wrapper,
        (x, x_lengths, temperature, length_scale),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_length", "temperature", "length_scale"],
        output_names=["mel"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "x_length": {0: "N"},
            "mel": {0: "N", 2: "L"},
        },
    )

    print("Generating the int8 quantized model")
    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
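A quick sanity check on the exported file, not part of this commit: load it with the onnx package, run the graph checker, and confirm the input names and dynamic axes declared above. This is a minimal sketch; it only assumes model.onnx exists in the current directory.

#!/usr/bin/env python3
# Sanity-check the exported graph (a sketch, not part of this commit).
import onnx

model = onnx.load("model.onnx")
onnx.checker.check_model(model)  # raises if the graph is malformed

for inp in model.graph.input:
    # Expect x, x_length, temperature, length_scale, with dynamic axes
    # N (batch) and L (num_tokens) as declared in torch.onnx.export above.
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)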

View File

@@ -5,6 +5,7 @@ import datetime as dt
 import logging
 from pathlib import Path
+import json

 import numpy as np
 import soundfile as sf
 import torch

@@ -29,7 +30,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=1320,
+        default=2810,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,

@@ -38,7 +39,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-fbank",
+        default="matcha/exp-new-3",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -51,6 +52,13 @@ def get_parser():
         default="data/tokens.txt",
     )

+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
     return parser
@@ -111,13 +119,21 @@ def main():
     params = get_params()
     params.update(vars(args))

-    logging.info(params)
-
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size

+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    logging.info(params)
+
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

@@ -127,9 +143,9 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")

     texts = [
-        "How are you doing, my friend",
-        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing? my friend.",
+        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]

     # Number of ODE Solver steps

@@ -174,7 +190,7 @@ def main():
         rtfs_w.append(rtf_w)

         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output-1320")
+        save_to_folder(i, output, folder=f"./my-output-{params.epoch}")

     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3

import logging

import onnxruntime as ort
import soundfile as sf
import torch

from inference import load_vocoder
from tokenizer import Tokenizer


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.tokenizer = Tokenizer("./data/tokens.txt")
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        for i in self.model.get_inputs():
            print(i)

        print("-----")

        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x: torch.Tensor):
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        print("x_lengths", x_lengths)
        print("x", x.shape)

        temperature = torch.tensor([1.0], dtype=torch.float32)
        length_scale = torch.tensor([1.0], dtype=torch.float32)

        mel = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lengths.numpy(),
                self.model.get_inputs()[2].name: temperature.numpy(),
                self.model.get_inputs()[3].name: length_scale.numpy(),
            },
        )[0]

        return torch.from_numpy(mel)


@torch.inference_mode()
def main():
    model = OnnxModel("./model.onnx")
    text = "hello, how are you doing?"
    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
    x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    x = torch.tensor(x, dtype=torch.int64)

    mel = model(x)
    print("mel", mel.shape)  # (1, 80, 170)

    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
    audio = vocoder(mel).clamp(-1, 1)
    print("audio", audio.shape)  # (1, 1, num_samples)
    audio = audio.squeeze()

    # skip denoiser
    sf.write("onnx.wav", audio, 22050, "PCM_16")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
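The int8 model produced by quantize_dynamic keeps the same input/output signature, so the OnnxModel wrapper above loads it unchanged. A minimal sketch of running both variants side by side, written as if appended to the script above (so torch and OnnxModel are already in scope); it assumes both files from the export script are present.

# Sketch: compare the fp32 and int8 models (not part of this commit).
fp32_model = OnnxModel("./model.onnx")
int8_model = OnnxModel("./model.int8.onnx")

text = "hello, how are you doing?"
x = fp32_model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.int64)

mel_fp32 = fp32_model(x)
mel_int8 = int8_model(x)
# synthesise() samples noise internally, so the two mels need not match
# exactly (or even in frame count); compare coarse statistics instead.
print("fp32", mel_fp32.shape, mel_fp32.mean().item())
print("int8", mel_int8.shape, mel_int8.mean().item())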

View File

@@ -7,6 +7,7 @@ import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
+import json

 import k2
 import torch
@@ -90,6 +91,13 @@ def get_parser():
         help="""Path to vocabulary.""",
     )

+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -123,11 +131,8 @@ def get_parser():
 def get_data_statistics():
     return AttributeDict(
         {
-            # "mel_mean": -5.517028331756592,  # matcha-tts
-            # "mel_std": 2.0643954277038574,
-            # ours
-            "mel_mean": -1.168782114982605,
-            "mel_std": 1.9283572435379028,
+            "mel_mean": 0,
+            "mel_std": 1,
         }
     )

@@ -138,9 +143,9 @@ def _get_data_params() -> AttributeDict:
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "batch_size": 64,
-            "num_workers": 1,
-            "pin_memory": False,
+            # "batch_size": 64,
+            # "num_workers": 1,
+            # "pin_memory": False,
             "cleaners": ["english_cleaners2"],
             "add_blank": True,
             "n_spks": 1,

@@ -312,7 +317,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
     tokens = batch["tokens"]

     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=False, add_eos=False
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
     )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)

@@ -619,10 +624,17 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")

     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.pad_id
+    params.pad_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
-    params.model_args.n_vocab = 178
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]

     logging.info(params)
     print(params)
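With this change, train.py, inference.py, and the export script all read the mel statistics from --cmvn rather than the values previously hard-coded in get_data_statistics(). A minimal sketch of the cmvn.json layout they expect; only the fbank_mean and fbank_std keys are read, and the numbers below are placeholders.

# Sketch of the cmvn.json file consumed via --cmvn (values are placeholders;
# the real file is computed from the training fbank features).
import json

stats = {"fbank_mean": -1.17, "fbank_std": 1.93}
with open("data/fbank/cmvn.json", "w") as f:
    json.dump(stats, f, indent=2)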

View File

@@ -24,7 +24,8 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse import CutSet, load_manifest_lazy
+from compute_fbank_ljspeech import MyFbank, MyFbankConfig
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,

@@ -176,22 +177,19 @@ class LJSpeechTtsDataModule:
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             train = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )

@@ -229,7 +227,8 @@ class LJSpeechTtsDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
             worker_init_fn=worker_init_fn,
         )

@@ -239,22 +238,19 @@ class LJSpeechTtsDataModule:
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             validate = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:

@@ -276,7 +272,8 @@ class LJSpeechTtsDataModule:
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
         )

         return valid_dl

@@ -285,22 +282,19 @@ class LJSpeechTtsDataModule:
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             test = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
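All three dataloaders now build on-the-fly features with MyFbank from compute_fbank_ljspeech.py (the symlink added at the top of this commit). A minimal sketch of exercising that config outside the datamodule; it assumes MyFbank follows Lhotse's FeatureExtractor interface (an extract(samples, sampling_rate) method), which OnTheFlyFeatures requires, and the parameter values simply mirror the diff above.

# Sketch: exercise the on-the-fly feature config directly (not part of this
# commit). MyFbank is assumed to implement Lhotse's FeatureExtractor API.
import numpy as np
from compute_fbank_ljspeech import MyFbank, MyFbankConfig

config = MyFbankConfig(
    n_fft=1024,
    n_mels=80,  # should match n_feats in ./train.py
    sampling_rate=22050,
    hop_length=256,
    win_length=1024,
    f_min=0,
    f_max=8000,
)
extractor = MyFbank(config)
samples = np.zeros(22050, dtype=np.float32)  # one second of silence
feats = extractor.extract(samples, sampling_rate=22050)
print(feats.shape)  # expected (num_frames, 80), about 22050 / 256 frames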