diff --git a/egs/baker_zh/TTS/matcha/export_onnx.py b/egs/baker_zh/TTS/matcha/export_onnx.py
index fc697643a..6133cc4db 100755
--- a/egs/baker_zh/TTS/matcha/export_onnx.py
+++ b/egs/baker_zh/TTS/matcha/export_onnx.py
@@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module):
         self,
         x: torch.Tensor,
         x_lengths: torch.Tensor,
-        temperature: torch.Tensor,
+        noise_scale: torch.Tensor,
         length_scale: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args: :
           x: (batch_size, num_tokens), torch.int64
           x_lengths: (batch_size,), torch.int64
-          temperature: (1,), torch.float32
+          noise_scale: (1,), torch.float32
           length_scale (1,), torch.float32
         Returns:
           audio: (batch_size, num_samples)
@@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module):
             x=x,
             x_lengths=x_lengths,
             n_timesteps=self.num_steps,
-            temperature=temperature,
+            temperature=noise_scale,
             length_scale=length_scale,
         )["mel"]
         # mel: (batch_size, feat_dim, num_frames)
@@ -153,14 +153,14 @@ def main():
     # encoder has a large initial length
     x = torch.ones(1, 1000, dtype=torch.int64)
     x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
-    temperature = torch.tensor([1.0])
+    noise_scale = torch.tensor([1.0])
     length_scale = torch.tensor([1.0])
 
     opset_version = 14
     filename = f"model-steps-{num_steps}.onnx"
     torch.onnx.export(
         wrapper,
-        (x, x_lengths, temperature, length_scale),
+        (x, x_lengths, noise_scale, length_scale),
         filename,
         opset_version=opset_version,
         input_names=["x", "x_length", "noise_scale", "length_scale"],
diff --git a/egs/baker_zh/TTS/matcha/onnx_pretrained.py b/egs/baker_zh/TTS/matcha/onnx_pretrained.py
new file mode 100755
index 000000000..615a3a0c8
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/onnx_pretrained.py
@@ -0,0 +1,306 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import datetime as dt
+import logging
+import re
+from typing import Dict, List
+
+import jieba
+import onnxruntime as ort
+import soundfile as sf
+import torch
+from infer import load_vocoder
+from utils import intersperse
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--acoustic-model",
+        type=str,
+        required=True,
+        help="Path to the acoustic model",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        required=True,
+        help="Path to the tokens.txt",
+    )
+
+    parser.add_argument(
+        "--lexicon",
+        type=str,
+        required=True,
+        help="Path to the lexicon.txt",
+    )
+
+    parser.add_argument(
+        "--vocoder",
+        type=str,
+        required=True,
+        help="Path to the vocoder",
+    )
+
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=True,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=True,
+        help="Filename of the wave file for saving the generated speech",
+    )
+
+    return parser
+
+
+class OnnxHifiGANModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 3, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        audio = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+            },
+        )[0]
+        # audio: (batch_size, num_samples)
+
+        return torch.from_numpy(audio)
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 2
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+        metadata = self.model.get_modelmeta().custom_metadata_map
+        self.sample_rate = int(metadata["sample_rate"])
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 2, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        print("x_lengths", x_lengths)
+        print("x", x.shape)
+
+        noise_scale = torch.tensor([1.0], dtype=torch.float32)
+        length_scale = torch.tensor([1.0], dtype=torch.float32)
+
+        mel = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: x_lengths.numpy(),
+                self.model.get_inputs()[2].name: noise_scale.numpy(),
+                self.model.get_inputs()[3].name: length_scale.numpy(),
+            },
+        )[0]
+        # mel: (batch_size, feat_dim, num_frames)
+
+        return torch.from_numpy(mel)
+
+
+def read_tokens(filename: str) -> Dict[str, int]:
+    token2id = dict()
+    with open(filename, encoding="utf-8") as f:
+        for line in f.readlines():
+            info = line.rstrip().split()
+            if len(info) == 1:
+                # case of space
+                token = " "
+                idx = int(info[0])
+            else:
+                token, idx = info[0], int(info[1])
+            assert token not in token2id, token
+            token2id[token] = idx
+    return token2id
+
+
+def read_lexicon(filename: str) -> Dict[str, List[str]]:
+    word2token = dict()
+    with open(filename, encoding="utf-8") as f:
+        for line in f.readlines():
+            info = line.rstrip().split()
+            w = info[0]
+            tokens = info[1:]
+            word2token[w] = tokens
+    return word2token
+
+
+def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
+    if word in word2tokens:
+        return word2tokens[word]
+
+    if len(word) == 1:
+        return []
+
+    ans = []
+    for w in word:
+        t = convert_word_to_tokens(word2tokens, w)
+        ans.extend(t)
+    return ans
+
+
+def normalize_text(text):
+    whitespace_re = re.compile(r"\s+")
+
+    punctuations_re = [
+        (re.compile(x[0], re.IGNORECASE), x[1])
+        for x in [
+            ("，", ","),
+            ("。", "."),
+            ("！", "!"),
+            ("？", "?"),
+            ("“", '"'),
+            ("”", '"'),
+            ("‘", "'"),
+            ("’", "'"),
+            ("：", ":"),
+            ("、", ","),
+        ]
+    ]
+
+    for regex, replacement in punctuations_re:
+        text = re.sub(regex, replacement, text)
+    # collapse runs of whitespace to a single space
+    return whitespace_re.sub(" ", text)
+
+
+@torch.no_grad()
+def main():
+    params = get_parser().parse_args()
+    logging.info(vars(params))
+    token2id = read_tokens(params.tokens)
+    word2tokens = read_lexicon(params.lexicon)
+
+    text = normalize_text(params.input_text)
+    seg = jieba.cut(text)
+    tokens = []
+    for s in seg:
+        if s in token2id:
+            tokens.append(s)
+            continue
+
+        t = convert_word_to_tokens(word2tokens, s)
+        if t:
+            tokens.extend(t)
+
+    model = OnnxModel(params.acoustic_model)
+    vocoder = OnnxHifiGANModel(params.vocoder)
+
+    x = []
+    for t in tokens:
+        if t in token2id:
+            x.append(token2id[t])
+
+    x = intersperse(x, item=token2id["_"])
+
+    x = torch.tensor(x, dtype=torch.int64).unsqueeze(0)
+
+    start_t = dt.datetime.now()
+    mel = model(x)
+    end_t = dt.datetime.now()
+
+    start_t2 = dt.datetime.now()
+    audio = vocoder(mel)
+    end_t2 = dt.datetime.now()
+
+    print("audio", audio.shape)  # (1, 1, num_samples)
+    audio = audio.squeeze()
+
+    sample_rate = model.sample_rate
+
+    t = (end_t - start_t).total_seconds()
+    t2 = (end_t2 - start_t2).total_seconds()
+    rtf_am = t * sample_rate / audio.shape[-1]
+    rtf_vocoder = t2 * sample_rate / audio.shape[-1]
+    print("RTF for acoustic model", rtf_am)
+    print("RTF for vocoder", rtf_vocoder)
+
+    # skip denoiser
+    sf.write(params.output_wav, audio, sample_rate, "PCM_16")
+    logging.info(f"Saved to {params.output_wav}")
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
+
+"""
+
+|HifiGAN   |RTF  |#Parameters (M)|
+|----------|-----|---------------|
+|v1        |0.818|     13.926    |
+|v2        |0.101|      0.925    |
+|v3        |0.118|      1.462    |
+
+|Num steps|Acoustic Model RTF|
+|---------|------------------|
+|    2    |      0.039       |
+|    3    |      0.047       |
+|    4    |      0.071       |
+|    5    |      0.076       |
+|    6    |      0.103       |
+
+"""
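
Usage note (not part of the patch): the new script is driven by the arguments
defined in get_parser(), i.e. --acoustic-model, --tokens, --lexicon, --vocoder,
--input-text and --output-wav. For reference, below is a minimal sketch of how
the two ONNX models are chained, mirroring main(); all file names in it are
hypothetical placeholders for your own exported/downloaded artifacts.

    # Minimal sketch (hypothetical file names), run from egs/baker_zh/TTS/matcha.
    import jieba
    import soundfile as sf
    import torch

    from onnx_pretrained import (
        OnnxHifiGANModel,
        OnnxModel,
        convert_word_to_tokens,
        normalize_text,
        read_lexicon,
        read_tokens,
    )
    from utils import intersperse

    token2id = read_tokens("./data/tokens.txt")  # hypothetical path
    word2tokens = read_lexicon("./lexicon.txt")  # hypothetical path

    # Text frontend: jieba segmentation, then lexicon lookup, as in main().
    ids = []
    for word in jieba.cut(normalize_text("你好，世界。")):
        if word in token2id:
            ids.append(token2id[word])  # e.g. punctuation that is itself a token
            continue
        for t in convert_word_to_tokens(word2tokens, word):
            if t in token2id:
                ids.append(token2id[t])

    ids = intersperse(ids, item=token2id["_"])  # interleave the "_" token, as in main()
    x = torch.tensor(ids, dtype=torch.int64).unsqueeze(0)  # (1, num_tokens)

    acoustic = OnnxModel("./model-steps-3.onnx")  # hypothetical file name
    vocoder = OnnxHifiGANModel("./hifigan_v2.onnx")  # hypothetical file name

    mel = acoustic(x)  # (1, feat_dim, num_frames)
    audio = vocoder(mel).squeeze()  # (num_samples,)
    sf.write("generated.wav", audio.numpy(), acoustic.sample_rate, "PCM_16")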