icefall/egs/vctk/TTS/vits/test_onnx.py

139 lines
4.1 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test the exported onnx model by vits/export-onnx.py
Use the onnx model to generate a wav:
./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
"""
import argparse
import logging
from pathlib import Path
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model.",
)
parser.add_argument(
"--speakers",
type=Path,
default=Path("data/speakers.txt"),
help="Path to speakers.txt file.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(
self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
) -> torch.Tensor:
"""
Args:
tokens:
A 1-D tensor of shape (1, T)
Returns:
A tensor of shape (1, T')
"""
noise_scale = torch.tensor([0.667], dtype=torch.float32)
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
alpha = torch.tensor([1.0], dtype=torch.float32)
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
self.model.get_inputs()[5].name: speaker.numpy(),
},
)[0]
return torch.from_numpy(out)
def main():
args = get_parser().parse_args()
tokenizer = Tokenizer(args.tokens)
with open(args.speakers) as f:
speaker_map = {line.strip(): i for i, line in enumerate(f)}
args.num_spks = len(speaker_map)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
speaker = torch.tensor([1], dtype=torch.int64) # (1, )
audio = model(tokens, tokens_lens, speaker) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()