mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
export to onnx
This commit is contained in:
parent
6478902108
commit
53221902cb
@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lengths: torch.Tensor,
|
x_lengths: torch.Tensor,
|
||||||
temperature: torch.Tensor,
|
noise_scale: torch.Tensor,
|
||||||
length_scale: torch.Tensor,
|
length_scale: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args: :
|
Args: :
|
||||||
x: (batch_size, num_tokens), torch.int64
|
x: (batch_size, num_tokens), torch.int64
|
||||||
x_lengths: (batch_size,), torch.int64
|
x_lengths: (batch_size,), torch.int64
|
||||||
temperature: (1,), torch.float32
|
noise_scale: (1,), torch.float32
|
||||||
length_scale (1,), torch.float32
|
length_scale (1,), torch.float32
|
||||||
Returns:
|
Returns:
|
||||||
audio: (batch_size, num_samples)
|
audio: (batch_size, num_samples)
|
||||||
@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module):
|
|||||||
x=x,
|
x=x,
|
||||||
x_lengths=x_lengths,
|
x_lengths=x_lengths,
|
||||||
n_timesteps=self.num_steps,
|
n_timesteps=self.num_steps,
|
||||||
temperature=temperature,
|
temperature=noise_scale,
|
||||||
length_scale=length_scale,
|
length_scale=length_scale,
|
||||||
)["mel"]
|
)["mel"]
|
||||||
# mel: (batch_size, feat_dim, num_frames)
|
# mel: (batch_size, feat_dim, num_frames)
|
||||||
@ -153,14 +153,14 @@ def main():
|
|||||||
# encoder has a large initial length
|
# encoder has a large initial length
|
||||||
x = torch.ones(1, 1000, dtype=torch.int64)
|
x = torch.ones(1, 1000, dtype=torch.int64)
|
||||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||||
temperature = torch.tensor([1.0])
|
noise_scale = torch.tensor([1.0])
|
||||||
length_scale = torch.tensor([1.0])
|
length_scale = torch.tensor([1.0])
|
||||||
|
|
||||||
opset_version = 14
|
opset_version = 14
|
||||||
filename = f"model-steps-{num_steps}.onnx"
|
filename = f"model-steps-{num_steps}.onnx"
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
wrapper,
|
wrapper,
|
||||||
(x, x_lengths, temperature, length_scale),
|
(x, x_lengths, noise_scale, length_scale),
|
||||||
filename,
|
filename,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
input_names=["x", "x_length", "noise_scale", "length_scale"],
|
input_names=["x", "x_length", "noise_scale", "length_scale"],
|
||||||
|
306
egs/baker_zh/TTS/matcha/onnx_pretrained.py
Executable file
306
egs/baker_zh/TTS/matcha/onnx_pretrained.py
Executable file
@ -0,0 +1,306 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
import onnxruntime as ort
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from infer import load_vocoder
|
||||||
|
from utils import intersperse
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--acoustic-model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the acoustic model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lexicon",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the lexicon.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocoder",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the vocoder",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-text",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The text to generate speech for",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-wav",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The filename of the wave to save the generated speech",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxHifiGANModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
self.model = ort.InferenceSession(
|
||||||
|
filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in self.model.get_inputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
print("-----")
|
||||||
|
|
||||||
|
for i in self.model.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
def __call__(self, x: torch.tensor):
|
||||||
|
assert x.ndim == 3, x.shape
|
||||||
|
assert x.shape[0] == 1, x.shape
|
||||||
|
|
||||||
|
audio = self.model.run(
|
||||||
|
[self.model.get_outputs()[0].name],
|
||||||
|
{
|
||||||
|
self.model.get_inputs()[0].name: x.numpy(),
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
# audio: (batch_size, num_samples)
|
||||||
|
|
||||||
|
return torch.from_numpy(audio)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 2
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
self.model = ort.InferenceSession(
|
||||||
|
filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||||
|
metadata = self.model.get_modelmeta().custom_metadata_map
|
||||||
|
self.sample_rate = int(metadata["sample_rate"])
|
||||||
|
|
||||||
|
for i in self.model.get_inputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
print("-----")
|
||||||
|
|
||||||
|
for i in self.model.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
def __call__(self, x: torch.tensor):
|
||||||
|
assert x.ndim == 2, x.shape
|
||||||
|
assert x.shape[0] == 1, x.shape
|
||||||
|
|
||||||
|
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||||
|
print("x_lengths", x_lengths)
|
||||||
|
print("x", x.shape)
|
||||||
|
|
||||||
|
noise_scale = torch.tensor([1.0], dtype=torch.float32)
|
||||||
|
length_scale = torch.tensor([1.0], dtype=torch.float32)
|
||||||
|
|
||||||
|
mel = self.model.run(
|
||||||
|
[self.model.get_outputs()[0].name],
|
||||||
|
{
|
||||||
|
self.model.get_inputs()[0].name: x.numpy(),
|
||||||
|
self.model.get_inputs()[1].name: x_lengths.numpy(),
|
||||||
|
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
||||||
|
self.model.get_inputs()[3].name: length_scale.numpy(),
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
# mel: (batch_size, feat_dim, num_frames)
|
||||||
|
|
||||||
|
return torch.from_numpy(mel)
|
||||||
|
|
||||||
|
|
||||||
|
def read_tokens(filename: str) -> Dict[str, int]:
|
||||||
|
token2id = dict()
|
||||||
|
with open(filename, encoding="utf-8") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
info = line.rstrip().split()
|
||||||
|
if len(info) == 1:
|
||||||
|
# case of space
|
||||||
|
token = " "
|
||||||
|
idx = int(info[0])
|
||||||
|
else:
|
||||||
|
token, idx = info[0], int(info[1])
|
||||||
|
assert token not in token2id, token
|
||||||
|
token2id[token] = idx
|
||||||
|
return token2id
|
||||||
|
|
||||||
|
|
||||||
|
def read_lexicon(filename: str) -> Dict[str, List[str]]:
|
||||||
|
word2token = dict()
|
||||||
|
with open(filename, encoding="utf-8") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
info = line.rstrip().split()
|
||||||
|
w = info[0]
|
||||||
|
tokens = info[1:]
|
||||||
|
word2token[w] = tokens
|
||||||
|
return word2token
|
||||||
|
|
||||||
|
|
||||||
|
def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
|
||||||
|
if word in word2tokens:
|
||||||
|
return word2tokens[word]
|
||||||
|
|
||||||
|
if len(word) == 1:
|
||||||
|
return []
|
||||||
|
|
||||||
|
ans = []
|
||||||
|
for w in word:
|
||||||
|
t = convert_word_to_tokens(word2tokens, w)
|
||||||
|
ans.extend(t)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text(text):
|
||||||
|
whiter_space_re = re.compile(r"\s+")
|
||||||
|
|
||||||
|
punctuations_re = [
|
||||||
|
(re.compile(x[0], re.IGNORECASE), x[1])
|
||||||
|
for x in [
|
||||||
|
(",", ","),
|
||||||
|
("。", "."),
|
||||||
|
("!", "!"),
|
||||||
|
("?", "?"),
|
||||||
|
("“", '"'),
|
||||||
|
("”", '"'),
|
||||||
|
("‘", "'"),
|
||||||
|
("’", "'"),
|
||||||
|
(":", ":"),
|
||||||
|
("、", ","),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
for regex, replacement in punctuations_re:
|
||||||
|
text = re.sub(regex, replacement, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
params = get_parser().parse_args()
|
||||||
|
logging.info(vars(params))
|
||||||
|
token2id = read_tokens(params.tokens)
|
||||||
|
word2tokens = read_lexicon(params.lexicon)
|
||||||
|
|
||||||
|
text = normalize_text(params.input_text)
|
||||||
|
seg = jieba.cut(text)
|
||||||
|
tokens = []
|
||||||
|
for s in seg:
|
||||||
|
if s in token2id:
|
||||||
|
tokens.append(s)
|
||||||
|
continue
|
||||||
|
|
||||||
|
t = convert_word_to_tokens(word2tokens, s)
|
||||||
|
if t:
|
||||||
|
tokens.extend(t)
|
||||||
|
|
||||||
|
model = OnnxModel(params.acoustic_model)
|
||||||
|
vocoder = OnnxHifiGANModel(params.vocoder)
|
||||||
|
|
||||||
|
x = []
|
||||||
|
for t in tokens:
|
||||||
|
if t in token2id:
|
||||||
|
x.append(token2id[t])
|
||||||
|
|
||||||
|
x = intersperse(x, item=token2id["_"])
|
||||||
|
|
||||||
|
x = torch.tensor(x, dtype=torch.int64).unsqueeze(0)
|
||||||
|
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
mel = model(x)
|
||||||
|
end_t = dt.datetime.now()
|
||||||
|
|
||||||
|
start_t2 = dt.datetime.now()
|
||||||
|
audio = vocoder(mel)
|
||||||
|
end_t2 = dt.datetime.now()
|
||||||
|
|
||||||
|
print("audio", audio.shape) # (1, 1, num_samples)
|
||||||
|
audio = audio.squeeze()
|
||||||
|
|
||||||
|
sample_rate = model.sample_rate
|
||||||
|
|
||||||
|
t = (end_t - start_t).total_seconds()
|
||||||
|
t2 = (end_t2 - start_t2).total_seconds()
|
||||||
|
rtf_am = t * sample_rate / audio.shape[-1]
|
||||||
|
rtf_vocoder = t2 * sample_rate / audio.shape[-1]
|
||||||
|
print("RTF for acoustic model ", rtf_am)
|
||||||
|
print("RTF for vocoder", rtf_vocoder)
|
||||||
|
|
||||||
|
# skip denoiser
|
||||||
|
sf.write(params.output_wav, audio, sample_rate, "PCM_16")
|
||||||
|
logging.info(f"Saved to {params.output_wav}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|HifiGAN |RTF |#Parameters (M)|
|
||||||
|
|----------|-----|---------------|
|
||||||
|
|v1 |0.818| 13.926 |
|
||||||
|
|v2 |0.101| 0.925 |
|
||||||
|
|v3 |0.118| 1.462 |
|
||||||
|
|
||||||
|
|Num steps|Acoustic Model RTF|
|
||||||
|
|---------|------------------|
|
||||||
|
| 2 | 0.039 |
|
||||||
|
| 3 | 0.047 |
|
||||||
|
| 4 | 0.071 |
|
||||||
|
| 5 | 0.076 |
|
||||||
|
| 6 | 0.103 |
|
||||||
|
|
||||||
|
"""
|
Loading…
x
Reference in New Issue
Block a user