icefall/egs/ljspeech/TTS/matcha/export_onnx.py
2024-10-21 21:24:29 +08:00

120 lines
3.3 KiB
Python
Executable File

#!/usr/bin/env python3
import json
import logging
import torch
from inference import get_parser
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
from onnxruntime.quantization import QuantType, quantize_dynamic
class ModelWrapper(torch.nn.Module):
    """Expose ``model.synthesise`` through a plain tensor-in/tensor-out ``forward``.

    ``synthesise`` returns a dict; this wrapper forwards the inputs, fixes the
    ``n_timesteps`` argument to a value chosen at construction time, and
    returns only the ``"mel"`` entry so that the module has a single tensor
    output.
    """

    def __init__(self, model, num_steps: int = 3):
        """
        Args:
          model:
            An object providing ``synthesise(x, x_lengths, n_timesteps,
            temperature, length_scale)`` that returns a dict with a ``"mel"``
            entry.
          num_steps:
            Value passed to ``synthesise`` as ``n_timesteps``. The default of
            3 matches the value that was previously hard-coded in ``forward``.
        """
        super().__init__()
        self.model = model
        self.num_steps = num_steps

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
        temperature: torch.Tensor,
        length_scale: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x: (batch_size, num_tokens), torch.int64
          x_lengths: (batch_size,), torch.int64
          temperature: (1,), torch.float32
          length_scale: (1,), torch.float32
        Returns:
          mel: (batch_size, feat_dim, num_frames)
        """
        mel = self.model.synthesise(
            x=x,
            x_lengths=x_lengths,
            n_timesteps=self.num_steps,
            temperature=temperature,
            length_scale=length_scale,
        )["mel"]
        # mel: (batch_size, feat_dim, num_frames)
        return mel
# NOTE: ``inference_mode`` is instantiated explicitly (with parentheses).
# The bare ``@torch.inference_mode`` form is not accepted by every PyTorch
# release; the called form works as a decorator wherever the API exists.
@torch.inference_mode()
def main():
    """Export a trained checkpoint to ONNX and to an int8-quantized ONNX file.

    Steps:
      1. Parse command-line arguments and merge them into the training params.
      2. Fill in vocabulary information from the tokenizer and the mel
         mean/std statistics from the JSON file given by ``params.cmvn``.
      3. Build the model, load ``{exp_dir}/epoch-{epoch}.pt`` into it and wrap
         it in ``ModelWrapper``.
      4. Run the wrapper once on dummy inputs, then export it to
         ``model.onnx`` in the current working directory.
      5. Dynamically quantize the exported file into ``model.int8.onnx``.
    """
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    # The statistics file is JSON; read it with an explicit encoding so the
    # result does not depend on the platform's default locale.
    with open(params.cmvn, encoding="utf-8") as f:
        stats = json.load(f)
        # The same statistics are written to both the data and model configs.
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    wrapper = ModelWrapper(model)
    wrapper.eval()

    # Use a large value so the rotary position embedding in the text
    # encoder has a large initial length
    x = torch.ones(1, 2000, dtype=torch.int64)
    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
    temperature = torch.tensor([1.0])
    length_scale = torch.tensor([1.0])

    # Eager run on the dummy inputs before exporting; prints the output shape.
    mel = wrapper(x, x_lengths, temperature, length_scale)
    print("mel", mel.shape)

    opset_version = 14
    filename = "model.onnx"
    torch.onnx.export(
        wrapper,
        (x, x_lengths, temperature, length_scale),
        filename,
        opset_version=opset_version,
        # These names are part of the exported model's interface; note the
        # second input is exported as "x_length" (singular).
        input_names=["x", "x_length", "temperature", "length_scale"],
        output_names=["mel"],
        # Only the axes listed here are dynamic. "temperature" and
        # "length_scale" are not listed, so they keep the shape of the dummy
        # inputs above.
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "x_length": {0: "N"},
            "mel": {0: "N", 2: "L"},
        },
    )

    print("Generate int8 quantization models")

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")
if __name__ == "__main__":
    # Configure root logging (timestamp, level, source location, message)
    # at INFO level before running the export.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )
    main()