support onnx export and testing the exported onnx model

This commit is contained in:
yaozengwei 2023-11-23 20:46:34 +08:00
parent 1ed6b4e143
commit a983dcd469
4 changed files with 410 additions and 12 deletions

View File

@ -0,0 +1,261 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a VITS model from PyTorch to ONNX.
Export the model to ONNX:
./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantizated model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxModel(nn.Module):
"""A wrapper for VITS generator."""
def __init__(self, model: nn.Module):
"""
Args:
model:
A VITS generator.
frame_shift:
The frame shift in samples.
"""
super().__init__()
self.model = model
def forward(
self,
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
Args:
tokens:
Input text token indexes (1, T_text)
tokens_lens:
Number of tokens of shape (1,)
noise_scale (float):
Noise scale parameter for flow.
noise_scale_dur (float):
Noise scale parameter for duration predictor.
alpha (float):
Alpha parameter to control the speed of generated speech.
Returns:
Return a tuple containing:
- audio, generated wavform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
)
return audio
def export_model_onnx(
model: nn.Module,
model_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
The exported model has one input:
- tokens, a tensor of shape (1, T_text); dtype is torch.int64
and it has one output:
- audio, a tensor of shape (1, T'); dtype is torch.float32
Args:
model:
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32)
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"},
},
)
meta_data = {
"model_type": "VITS",
"version": "1",
"model_author": "k2-fsa",
"comment": "VITS generator",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
suffix = f"epoch-{params.epoch}"
opset_version = 13
logging.info("Exporting encoder")
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
export_model_onnx(
model,
model_filename,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -403,6 +403,7 @@ class VITSGenerator(torch.nn.Module):
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
x_mask = x_mask.to(x.dtype)
g = None
if self.spks is not None:
# (B, global_channels, 1)
@ -480,6 +481,7 @@ class VITSGenerator(torch.nn.Module):
dur = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
y_mask = y_mask.to(x.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = self._generate_path(dur, attn_mask)

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test the exported onnx model by vits/export-onnx.py
Use the onnx model to generate a wav:
./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
"""
import argparse
import logging
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
tokens:
A 1-D tensor of shape (1, T)
Returns:
A tensor of shape (1, T')
"""
noise_scale = torch.tensor([0.667], dtype=torch.float32)
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
alpha = torch.tensor([1.0], dtype=torch.float32)
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(),
},
)[0]
return torch.from_numpy(out)
def main():
args = get_parser().parse_args()
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -30,7 +30,7 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.utils import is_jit_tracing, make_pad_mask
class TextEncoder(torch.nn.Module):
@ -440,18 +440,30 @@ class RelPositionMultiheadAttention(nn.Module):
"""
(batch_size, num_heads, seq_len, n) = x.shape
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
if not is_jit_tracing():
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, seq_len, seq_len),
(batch_stride, head_stride, time_stride - n_stride, n_stride),
storage_offset=n_stride * (seq_len - 1),
)
if is_jit_tracing():
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, seq_len, seq_len),
(batch_stride, head_stride, time_stride - n_stride, n_stride),
storage_offset=n_stride * (seq_len - 1),
)
def forward(
self,