export to onnx

Fangjun Kuang 2024-12-30 15:51:55 +08:00
parent 6478902108
commit 53221902cb
2 changed files with 311 additions and 5 deletions

@@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module):
         self,
         x: torch.Tensor,
         x_lengths: torch.Tensor,
-        temperature: torch.Tensor,
+        noise_scale: torch.Tensor,
         length_scale: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
           x: (batch_size, num_tokens), torch.int64
           x_lengths: (batch_size,), torch.int64
-          temperature: (1,), torch.float32
+          noise_scale: (1,), torch.float32
           length_scale: (1,), torch.float32
         Returns:
           audio: (batch_size, num_samples)
@@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module):
             x=x,
             x_lengths=x_lengths,
             n_timesteps=self.num_steps,
-            temperature=temperature,
+            temperature=noise_scale,
             length_scale=length_scale,
         )["mel"]
         # mel: (batch_size, feat_dim, num_frames)
@@ -153,14 +153,14 @@ def main():
     # encoder has a large initial length
     x = torch.ones(1, 1000, dtype=torch.int64)
     x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
-    temperature = torch.tensor([1.0])
+    noise_scale = torch.tensor([1.0])
     length_scale = torch.tensor([1.0])
 
     opset_version = 14
     filename = f"model-steps-{num_steps}.onnx"
     torch.onnx.export(
         wrapper,
-        (x, x_lengths, temperature, length_scale),
+        (x, x_lengths, noise_scale, length_scale),
         filename,
         opset_version=opset_version,
         input_names=["x", "x_length", "noise_scale", "length_scale"],
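The runner added in the second file reads sample_rate from the model's custom metadata (get_modelmeta().custom_metadata_map), so the export step presumably attaches it; that part is not visible in the hunks above. A minimal sketch of how such metadata can be added with the onnx package, where the filename and the 22050 Hz value are placeholders rather than values taken from this commit:

import onnx

# Hedged sketch: attach sample_rate as custom metadata after torch.onnx.export
# so an onnxruntime session can read it via get_modelmeta().custom_metadata_map.
# "model-steps-2.onnx" and 22050 are placeholder values.
m = onnx.load("model-steps-2.onnx")
entry = m.metadata_props.add()
entry.key = "sample_rate"
entry.value = "22050"
onnx.save(m, "model-steps-2.onnx")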

@@ -0,0 +1,306 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import datetime as dt
import re
import logging
from typing import Dict, List

import jieba
import onnxruntime as ort
import soundfile as sf
import torch
from infer import load_vocoder
from utils import intersperse

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--acoustic-model",
        type=str,
        required=True,
        help="Path to the acoustic model",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--lexicon",
        type=str,
        required=True,
        help="Path to the lexicon.txt",
    )

    parser.add_argument(
        "--vocoder",
        type=str,
        required=True,
        help="Path to the vocoder",
    )

    parser.add_argument(
        "--input-text",
        type=str,
        required=True,
        help="The text to generate speech for",
    )

    parser.add_argument(
        "--output-wav",
        type=str,
        required=True,
        help="The filename of the wave to save the generated speech",
    )

    return parser

class OnnxHifiGANModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        for i in self.model.get_inputs():
            print(i)
        print("-----")
        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x: torch.Tensor):
        assert x.ndim == 3, x.shape
        assert x.shape[0] == 1, x.shape

        audio = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]
        # audio: (batch_size, num_samples)
        return torch.from_numpy(audio)

class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 2

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

        metadata = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(metadata["sample_rate"])

        for i in self.model.get_inputs():
            print(i)
        print("-----")
        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x: torch.Tensor):
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        print("x_lengths", x_lengths)
        print("x", x.shape)

        noise_scale = torch.tensor([1.0], dtype=torch.float32)
        length_scale = torch.tensor([1.0], dtype=torch.float32)

        mel = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lengths.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: length_scale.numpy(),
            },
        )[0]
        # mel: (batch_size, feat_dim, num_frames)
        return torch.from_numpy(mel)

def read_tokens(filename: str) -> Dict[str, int]:
    token2id = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f.readlines():
            info = line.rstrip().split()
            if len(info) == 1:
                # case of space
                token = " "
                idx = int(info[0])
            else:
                token, idx = info[0], int(info[1])
            assert token not in token2id, token
            token2id[token] = idx
    return token2id
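read_tokens expects one symbol and its integer id per line, separated by whitespace; a line that holds only an id denotes the space token. A small runnable sketch with made-up entries (the recipe's real tokens.txt will differ):

# Hedged illustration of the expected tokens.txt format; the entries are invented.
with open("tokens-example.txt", "w", encoding="utf-8") as f:
    f.write("_ 0\na 1\nb 2\n3\n")

print(read_tokens("tokens-example.txt"))
# -> {'_': 0, 'a': 1, 'b': 2, ' ': 3}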

def read_lexicon(filename: str) -> Dict[str, List[str]]:
    word2token = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f.readlines():
            info = line.rstrip().split()
            w = info[0]
            tokens = info[1:]
            word2token[w] = tokens
    return word2token

def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
    if word in word2tokens:
        return word2tokens[word]

    if len(word) == 1:
        return []

    ans = []
    for w in word:
        t = convert_word_to_tokens(word2tokens, w)
        ans.extend(t)
    return ans
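read_lexicon expects each line to hold a word followed by its token sequence, and convert_word_to_tokens falls back to a per-character lookup when the whole word is missing from the lexicon. A short sketch with invented entries (not the recipe's real lexicon.txt):

# Hedged illustration; the words and token spellings below are made up.
word2tokens_example = {
    "你好": ["n", "i2", "h", "ao3"],
    "你": ["n", "i3"],
    "好": ["h", "ao3"],
}
print(convert_word_to_tokens(word2tokens_example, "你好"))  # direct hit: ['n', 'i2', 'h', 'ao3']
print(convert_word_to_tokens(word2tokens_example, "好你"))  # per-character fallback: ['h', 'ao3', 'n', 'i3']
print(convert_word_to_tokens(word2tokens_example, "哈"))  # unknown single character: []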

def normalize_text(text):
    white_space_re = re.compile(r"\s+")

    punctuations_re = [
        (re.compile(x[0], re.IGNORECASE), x[1])
        for x in [
            ("，", ","),
            ("。", "."),
            ("！", "!"),
            ("？", "?"),
            ("“", '"'),
            ("”", '"'),
            ("‘", "'"),
            ("’", "'"),
            ("：", ":"),
            ("、", ","),
        ]
    ]

    for regex, replacement in punctuations_re:
        text = re.sub(regex, replacement, text)
    return text
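With the punctuation table above, normalize_text rewrites full-width Chinese punctuation as its ASCII counterpart before jieba segmentation, for example:

# Output assumes the punctuation table reconstructed above.
print(normalize_text("你好，世界！"))  # -> 你好,世界!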

@torch.no_grad()
def main():
    params = get_parser().parse_args()
    logging.info(vars(params))

    token2id = read_tokens(params.tokens)
    word2tokens = read_lexicon(params.lexicon)

    text = normalize_text(params.input_text)
    seg = jieba.cut(text)
    tokens = []
    for s in seg:
        if s in token2id:
            tokens.append(s)
            continue
        t = convert_word_to_tokens(word2tokens, s)
        if t:
            tokens.extend(t)

    model = OnnxModel(params.acoustic_model)
    vocoder = OnnxHifiGANModel(params.vocoder)

    x = []
    for t in tokens:
        if t in token2id:
            x.append(token2id[t])

    x = intersperse(x, item=token2id["_"])
    x = torch.tensor(x, dtype=torch.int64).unsqueeze(0)

    start_t = dt.datetime.now()
    mel = model(x)
    end_t = dt.datetime.now()

    start_t2 = dt.datetime.now()
    audio = vocoder(mel)
    end_t2 = dt.datetime.now()

    print("audio", audio.shape)  # (1, 1, num_samples)
    audio = audio.squeeze()

    sample_rate = model.sample_rate

    t = (end_t - start_t).total_seconds()
    t2 = (end_t2 - start_t2).total_seconds()
    rtf_am = t * sample_rate / audio.shape[-1]
    rtf_vocoder = t2 * sample_rate / audio.shape[-1]
    print("RTF for acoustic model ", rtf_am)
    print("RTF for vocoder", rtf_vocoder)

    # skip denoiser
    sf.write(params.output_wav, audio, sample_rate, "PCM_16")
    logging.info(f"Saved to {params.output_wav}")

if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
"""
|HifiGAN |RTF |#Parameters (M)|
|----------|-----|---------------|
|v1 |0.818| 13.926 |
|v2 |0.101| 0.925 |
|v3 |0.118| 1.462 |
|Num steps|Acoustic Model RTF|
|---------|------------------|
| 2 | 0.039 |
| 3 | 0.047 |
| 4 | 0.071 |
| 5 | 0.076 |
| 6 | 0.103 |
"""