mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
85 lines
2.5 KiB
Python
Executable File
85 lines
2.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import logging
|
|
|
|
import onnxruntime as ort
|
|
import torch
|
|
from tokenizer import Tokenizer
|
|
|
|
from inference import load_vocoder
|
|
import soundfile as sf
|
|
|
|
|
|
class OnnxModel:
    """Run an exported acoustic model with onnxruntime on CPU.

    The wrapper also owns a ``Tokenizer`` (``self.tokenizer``) so callers can
    convert text to token IDs before invoking the model.
    """

    def __init__(
        self,
        filename: str,
        tokens: str = "./data/tokens.txt",
        num_threads: int = 1,
    ):
        """
        Args:
          filename:
            Path to the ONNX model file.
          tokens:
            Path handed to ``Tokenizer``. Defaults to the path this script
            has always used.
          num_threads:
            Value used for both ``inter_op_num_threads`` and
            ``intra_op_num_threads`` of the onnxruntime session.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = num_threads
        session_opts.intra_op_num_threads = num_threads

        self.session_opts = session_opts
        self.tokenizer = Tokenizer(tokens)
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        # Log the graph signature once so a mismatch with the positional
        # feeds built in __call__ is easy to diagnose.
        for node in self.model.get_inputs():
            logging.info("input: %s", node)
        for node in self.model.get_outputs():
            logging.info("output: %s", node)

    def __call__(
        self,
        x: torch.Tensor,
        temperature: float = 1.0,
        length_scale: float = 1.0,
    ) -> torch.Tensor:
        """Run the model on a single tokenized utterance.

        Args:
          x:
            2-D tensor of token IDs with batch size 1, i.e. shape (1, T).
            main() builds it as int64; presumably the exported graph expects
            int64 as well -- confirm against the model.
          temperature:
            Scalar fed (as a float32 tensor of shape (1,)) to the model's
            third input.
          length_scale:
            Scalar fed (as a float32 tensor of shape (1,)) to the model's
            fourth input.
        Returns:
          The model's first output as a torch tensor. Callers treat it as a
          mel spectrogram, presumably (1, num_mels, num_frames) -- TODO
          confirm against the exported graph.
        """
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        logging.debug("x_lengths %s, x %s", x_lengths, x.shape)

        temperature_t = torch.tensor([temperature], dtype=torch.float32)
        length_scale_t = torch.tensor([length_scale], dtype=torch.float32)

        # Feeds are bound positionally: this assumes the exported graph
        # declares its inputs in the order (tokens, lengths, temperature,
        # length_scale).
        inputs = self.model.get_inputs()
        feeds = {
            inputs[0].name: x.cpu().numpy(),
            inputs[1].name: x_lengths.numpy(),
            inputs[2].name: temperature_t.numpy(),
            inputs[3].name: length_scale_t.numpy(),
        }
        output_name = self.model.get_outputs()[0].name

        mel = self.model.run([output_name], feeds)[0]
        return torch.from_numpy(mel)
|
|
|
|
|
|
@torch.inference_mode()
def main(
    model_filename: str = "./model.onnx",
    # NOTE(review): machine-specific default kept for backward compatibility;
    # pass a path valid on your machine.
    vocoder_filename: str = "/star-fj/fangjun/open-source/Matcha-TTS/generator_v1",
    output_filename: str = "onnx.wav",
    sample_rate: int = 22050,
    text: str = (
        "Today as always, men fall into two groups: slaves and free men. "
        "Whoever does not have two-thirds of his day for himself, is a slave, "
        "whatever he may be: a statesman, a businessman, an official, "
        "or a scholar."
    ),
):
    """Synthesize ``text`` with the ONNX acoustic model and a vocoder.

    Args:
      model_filename:
        Path to the ONNX acoustic model.
      vocoder_filename:
        Path handed to ``load_vocoder``.
      output_filename:
        Where the synthesized waveform is written (16-bit PCM).
      sample_rate:
        Sample rate written into the output file header.
      text:
        Sentence to synthesize.
    """
    model = OnnxModel(model_filename)

    # Wrap the single sentence with SOS/EOS and make a (1, T) int64 batch.
    token_ids = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    x = torch.tensor(token_ids, dtype=torch.int64)

    mel = model(x)
    logging.info("mel shape: %s", mel.shape)  # e.g. (1, 80, 170) for the default text

    vocoder = load_vocoder(vocoder_filename)
    audio = vocoder(mel).clamp(-1, 1)
    logging.info("audio shape: %s", audio.shape)  # (1, 1, num_samples)

    # skip denoiser
    # soundfile works on NumPy arrays; convert explicitly (moving to CPU in
    # case the vocoder is not on CPU) rather than relying on implicit
    # __array__ conversion.
    samples = audio.squeeze().cpu().numpy()
    sf.write(output_filename, samples, sample_rate, "PCM_16")
|
|
|
|
|
|
if __name__ == "__main__":
    # Timestamped INFO-level log lines tagged with the emitting file and line.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )
    main()
|