icefall/egs/ljspeech/TTS/matcha/onnx_pretrained.py
2024-10-21 21:24:29 +08:00

85 lines
2.5 KiB
Python
Executable File

#!/usr/bin/env python3
import logging
import onnxruntime as ort
import torch
from tokenizer import Tokenizer
from inference import load_vocoder
import soundfile as sf
class OnnxModel:
    """Wrap an ONNX text-to-mel model together with the tokenizer feeding it.

    The onnxruntime session is created on the CPU execution provider with a
    single inter-op thread and a single intra-op thread.

    Instance attributes (all set in ``__init__``):
      - session_opts: the ``ort.SessionOptions`` used to build the session.
      - tokenizer: the ``Tokenizer`` built from the tokens file.
      - model: the ``ort.InferenceSession`` loaded from ``filename``.
    """

    def __init__(
        self,
        filename: str,
        tokens: str = "./data/tokens.txt",
    ):
        """Load the ONNX model and the tokenizer.

        Args:
          filename:
            Path to the ONNX model file.
          tokens:
            Path handed to ``Tokenizer``. The default is the location that
            used to be hard-coded, so existing callers are unaffected.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.tokenizer = Tokenizer(tokens)
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        # Dump the graph signature (inputs, separator, outputs) for debugging.
        for node in self.model.get_inputs():
            logging.info("%s", node)
        logging.info("-----")
        for node in self.model.get_outputs():
            logging.info("%s", node)

    def __call__(
        self,
        x: torch.Tensor,
        temperature: float = 1.0,
        length_scale: float = 1.0,
    ) -> torch.Tensor:
        """Run the model on one tokenized utterance.

        Args:
          x:
            2-D tensor of token IDs whose first (batch) dimension is exactly
            1, i.e. shape (1, num_tokens). It is converted with ``.numpy()``
            and fed to the model unchanged.
            NOTE(review): the caller in this file passes int64; presumably
            the exported graph requires that dtype - confirm.
          temperature:
            Value fed to the third model input as a 1-element float32 array.
            NOTE(review): presumably the sampling noise scale - confirm
            against the export script.
          length_scale:
            Value fed to the fourth model input as a 1-element float32 array.
            NOTE(review): presumably controls speaking rate - confirm
            against the export script.

        Returns:
          The first model output wrapped in a torch tensor (the caller in
          this file treats it as a mel spectrogram).

        Raises:
          AssertionError: if ``x`` is not 2-D or its batch size is not 1.
        """
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        # Only one utterance is supported, so its length is the token count.
        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        logging.info("x_lengths %s", x_lengths)
        logging.info("x %s", x.shape)

        temperature_t = torch.tensor([temperature], dtype=torch.float32)
        length_scale_t = torch.tensor([length_scale], dtype=torch.float32)

        # Inputs are addressed by position; the order below mirrors the
        # order in which the graph declares them.
        inputs = self.model.get_inputs()
        feeds = {
            inputs[0].name: x.numpy(),
            inputs[1].name: x_lengths.numpy(),
            inputs[2].name: temperature_t.numpy(),
            inputs[3].name: length_scale_t.numpy(),
        }

        mel = self.model.run([self.model.get_outputs()[0].name], feeds)[0]
        return torch.from_numpy(mel)
@torch.inference_mode()
def main(
    text: str = (
        "Today as always, men fall into two groups: slaves and free men. "
        "Whoever does not have two-thirds of his day for himself, is a slave, "
        "whatever he may be: a statesman, a businessman, an official, or a scholar."
    ),
    model_filename: str = "./model.onnx",
    vocoder_path: str = "/star-fj/fangjun/open-source/Matcha-TTS/generator_v1",
    output_filename: str = "onnx.wav",
):
    """Synthesize ``text`` with the ONNX model and write it to a wave file.

    Steps: tokenize the text, run the ONNX model to obtain a mel tensor, run
    the vocoder on it, clamp the waveform to [-1, 1] and save it as 16-bit
    PCM at 22050 Hz. Every default reproduces the values that used to be
    hard-coded, so calling ``main()`` without arguments behaves as before.

    Args:
      text:
        The sentence to synthesize.
      model_filename:
        Path to the ONNX model passed to ``OnnxModel``.
      vocoder_path:
        Path passed to ``load_vocoder``.
      output_filename:
        Where the generated wave file is written.
    """
    model = OnnxModel(model_filename)

    # Start/end-of-sentence tokens are added around the single utterance.
    token_ids = model.tokenizer.texts_to_token_ids(
        [text], add_sos=True, add_eos=True
    )
    x = torch.tensor(token_ids, dtype=torch.int64)

    mel = model(x)
    # Original author's note for the default text: (1, 80, 170)
    logging.info("mel %s", mel.shape)

    vocoder = load_vocoder(vocoder_path)
    audio = vocoder(mel).clamp(-1, 1)
    # Original author's note: (1, 1, num_samples)
    logging.info("audio %s", audio.shape)

    # skip denoiser
    # Drop the singleton dimensions and hand soundfile a NumPy array
    # explicitly instead of relying on implicit tensor conversion.
    samples = audio.squeeze().numpy()
    sf.write(output_filename, samples, 22050, "PCM_16")
if __name__ == "__main__":
    # Configure root logging (timestamp, level, source location, message)
    # at INFO level, then run the demo.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()