mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
export to onnx
This commit is contained in:
parent
6478902108
commit
53221902cb
@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lengths: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
noise_scale: torch.Tensor,
|
||||
length_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (batch_size, num_tokens), torch.int64
|
||||
x_lengths: (batch_size,), torch.int64
|
||||
temperature: (1,), torch.float32
|
||||
noise_scale: (1,), torch.float32
|
||||
length_scale: (1,), torch.float32
|
||||
Returns:
|
||||
audio: (batch_size, num_samples)
|
||||
@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module):
|
||||
x=x,
|
||||
x_lengths=x_lengths,
|
||||
n_timesteps=self.num_steps,
|
||||
temperature=temperature,
|
||||
temperature=noise_scale,
|
||||
length_scale=length_scale,
|
||||
)["mel"]
|
||||
# mel: (batch_size, feat_dim, num_frames)
|
||||
@ -153,14 +153,14 @@ def main():
|
||||
# encoder has a large initial length
|
||||
x = torch.ones(1, 1000, dtype=torch.int64)
|
||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
temperature = torch.tensor([1.0])
|
||||
noise_scale = torch.tensor([1.0])
|
||||
length_scale = torch.tensor([1.0])
|
||||
|
||||
opset_version = 14
|
||||
filename = f"model-steps-{num_steps}.onnx"
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
(x, x_lengths, temperature, length_scale),
|
||||
(x, x_lengths, noise_scale, length_scale),
|
||||
filename,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_length", "noise_scale", "length_scale"],
|
||||
|
306
egs/baker_zh/TTS/matcha/onnx_pretrained.py
Executable file
306
egs/baker_zh/TTS/matcha/onnx_pretrained.py
Executable file
@ -0,0 +1,306 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import jieba
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from infer import load_vocoder
|
||||
from utils import intersperse
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser for ONNX Matcha-TTS inference.

    All options are required string paths/values; defaults are shown in
    ``--help`` via ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, help text) for the required string arguments.
    required_options = [
        ("--acoustic-model", "Path to the acoustic model"),
        ("--tokens", "Path to the tokens.txt"),
        ("--lexicon", "Path to the lexicon.txt"),
        ("--vocoder", "Path to the vocoder"),
        ("--input-text", "The text to generate speech for"),
        ("--output-wav", "The filename of the wave to save the generated speech"),
    ]
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    return parser
||||
class OnnxHifiGANModel:
    """ONNX Runtime wrapper around an exported HiFi-GAN vocoder.

    Converts a mel spectrogram into a waveform. Runs on CPU with a
    single thread so timing measurements are reproducible.
    """

    def __init__(
        self,
        filename: str,
    ):
        """
        Args:
          filename: Path to the HiFi-GAN ONNX model file.
        """
        session_opts = ort.SessionOptions()
        # Single-threaded so the RTF numbers printed by main() are stable.
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        # Print the model's I/O signature for debugging.
        for i in self.model.get_inputs():
            print(i)

        print("-----")

        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x: torch.Tensor):
        """Run the vocoder on a single utterance.

        Args:
          x: A 3-D tensor with batch size 1; the caller passes the mel
             output of the acoustic model (batch_size, feat_dim, num_frames).
        Returns:
          A torch tensor with the generated audio samples,
          (batch_size, num_samples).
        """
        # NOTE: fixed the annotation — ``torch.tensor`` is a factory
        # function; the type is ``torch.Tensor``.
        assert x.ndim == 3, x.shape
        assert x.shape[0] == 1, x.shape

        audio = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]
        # audio: (batch_size, num_samples)

        return torch.from_numpy(audio)
||||
class OnnxModel:
    """ONNX Runtime wrapper around the exported Matcha acoustic model.

    Maps a token-id sequence to a mel spectrogram. The sample rate is read
    from the model's custom metadata (``sample_rate`` key).
    """

    def __init__(
        self,
        filename: str,
    ):
        """
        Args:
          filename: Path to the acoustic-model ONNX file.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 2

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
        metadata = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(metadata["sample_rate"])

        # Print the model's I/O signature for debugging.
        for i in self.model.get_inputs():
            print(i)

        print("-----")

        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x: torch.Tensor):
        """Run the acoustic model on a single utterance.

        Args:
          x: Token ids, (1, num_tokens), torch.int64.
        Returns:
          A torch tensor with the mel spectrogram,
          (batch_size, feat_dim, num_frames).
        """
        # NOTE: fixed the annotation — ``torch.tensor`` is a factory
        # function; the type is ``torch.Tensor``.
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        print("x_lengths", x_lengths)
        print("x", x.shape)

        # Defaults matching the export script's example inputs.
        noise_scale = torch.tensor([1.0], dtype=torch.float32)
        length_scale = torch.tensor([1.0], dtype=torch.float32)

        mel = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lengths.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: length_scale.numpy(),
            },
        )[0]
        # mel: (batch_size, feat_dim, num_frames)

        return torch.from_numpy(mel)
||||
def read_tokens(filename: str) -> Dict[str, int]:
    """Read a token-to-id map from ``tokens.txt``.

    Each line is ``<token> <id>``; a line with a single field is the
    space token (the token itself is a space, so only the id survives
    ``split()``).

    Args:
      filename: Path to tokens.txt.
    Returns:
      A dict mapping each token to its integer id.
    """
    token2id = dict()
    with open(filename, encoding="utf-8") as f:
        # Iterate the file directly instead of materializing readlines().
        for line in f:
            info = line.rstrip().split()
            if not info:
                # Skip blank lines (e.g. a trailing newline).
                continue
            if len(info) == 1:
                # case of space
                token = " "
                idx = int(info[0])
            else:
                token, idx = info[0], int(info[1])
            assert token not in token2id, token
            token2id[token] = idx
    return token2id
||||
def read_lexicon(filename: str) -> Dict[str, List[str]]:
    """Read a word-to-tokens map from ``lexicon.txt``.

    Each line is ``<word> <token1> <token2> ...``.

    Args:
      filename: Path to lexicon.txt.
    Returns:
      A dict mapping each word to its list of tokens.
    """
    word2token = dict()
    with open(filename, encoding="utf-8") as f:
        # Iterate the file directly instead of materializing readlines().
        for line in f:
            info = line.rstrip().split()
            if not info:
                # Skip blank lines (e.g. a trailing newline).
                continue
            w = info[0]
            tokens = info[1:]
            word2token[w] = tokens
    return word2token
||||
def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
    """Map ``word`` to its token sequence via the lexicon.

    Falls back to resolving each character independently when the whole
    word is absent; a single unknown character yields an empty list.
    """
    if word in word2tokens:
        return word2tokens[word]

    # Single unknown character: nothing to decompose further.
    if len(word) == 1:
        return []

    # Unknown multi-character word: concatenate per-character lookups.
    return [
        tok
        for ch in word
        for tok in convert_word_to_tokens(word2tokens, ch)
    ]
||||
def normalize_text(text):
|
||||
whiter_space_re = re.compile(r"\s+")
|
||||
|
||||
punctuations_re = [
|
||||
(re.compile(x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
(",", ","),
|
||||
("。", "."),
|
||||
("!", "!"),
|
||||
("?", "?"),
|
||||
("“", '"'),
|
||||
("”", '"'),
|
||||
("‘", "'"),
|
||||
("’", "'"),
|
||||
(":", ":"),
|
||||
("、", ","),
|
||||
]
|
||||
]
|
||||
|
||||
for regex, replacement in punctuations_re:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Synthesize speech for ``--input-text`` with the ONNX acoustic model
    and vocoder, then save the waveform to ``--output-wav``."""
    params = get_parser().parse_args()
    logging.info(vars(params))
    token2id = read_tokens(params.tokens)
    word2tokens = read_lexicon(params.lexicon)

    # Segment the normalized text and map each word to tokens: a word that
    # is itself a token is kept directly, otherwise the lexicon (with
    # per-character fallback) is consulted.
    text = normalize_text(params.input_text)
    tokens = []
    for word in jieba.cut(text):
        if word in token2id:
            tokens.append(word)
            continue

        word_tokens = convert_word_to_tokens(word2tokens, word)
        if word_tokens:
            tokens.extend(word_tokens)

    model = OnnxModel(params.acoustic_model)
    vocoder = OnnxHifiGANModel(params.vocoder)

    # Keep only tokens with a known id, then intersperse the blank "_".
    token_ids = [token2id[t] for t in tokens if t in token2id]
    token_ids = intersperse(token_ids, item=token2id["_"])

    x = torch.tensor(token_ids, dtype=torch.int64).unsqueeze(0)

    # Time the acoustic model and the vocoder separately for RTF reporting.
    am_begin = dt.datetime.now()
    mel = model(x)
    am_finish = dt.datetime.now()

    voc_begin = dt.datetime.now()
    audio = vocoder(mel)
    voc_finish = dt.datetime.now()

    print("audio", audio.shape)  # (1, 1, num_samples)
    audio = audio.squeeze()

    sample_rate = model.sample_rate

    am_seconds = (am_finish - am_begin).total_seconds()
    voc_seconds = (voc_finish - voc_begin).total_seconds()
    rtf_am = am_seconds * sample_rate / audio.shape[-1]
    rtf_vocoder = voc_seconds * sample_rate / audio.shape[-1]
    print("RTF for acoustic model ", rtf_am)
    print("RTF for vocoder", rtf_vocoder)

    # skip denoiser
    sf.write(params.output_wav, audio, sample_rate, "PCM_16")
    logging.info(f"Saved to {params.output_wav}")
||||
if __name__ == "__main__":
    # Single-threaded torch keeps the RTF numbers printed by main() stable.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()
||||
"""
|
||||
|
||||
|HifiGAN |RTF |#Parameters (M)|
|
||||
|----------|-----|---------------|
|
||||
|v1 |0.818| 13.926 |
|
||||
|v2 |0.101| 0.925 |
|
||||
|v3 |0.118| 1.462 |
|
||||
|
||||
|Num steps|Acoustic Model RTF|
|
||||
|---------|------------------|
|
||||
| 2 | 0.039 |
|
||||
| 3 | 0.047 |
|
||||
| 4 | 0.071 |
|
||||
| 5 | 0.076 |
|
||||
| 6 | 0.103 |
|
||||
|
||||
"""
|
Loading…
x
Reference in New Issue
Block a user