From 647890210868504e5508107d7ff68ded265c2f10 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 27 Dec 2024 19:16:36 +0800
Subject: [PATCH] generate lexicon and export onnx models

---
 egs/baker_zh/TTS/matcha/export_onnx.py      | 200 ++++++++++++++++++
 .../TTS/matcha/export_onnx_hifigan.py       |   1 +
 egs/baker_zh/TTS/matcha/generate_lexicon.py |  42 ++++
 3 files changed, 243 insertions(+)
 create mode 100755 egs/baker_zh/TTS/matcha/export_onnx.py
 create mode 120000 egs/baker_zh/TTS/matcha/export_onnx_hifigan.py
 create mode 100755 egs/baker_zh/TTS/matcha/generate_lexicon.py

diff --git a/egs/baker_zh/TTS/matcha/export_onnx.py b/egs/baker_zh/TTS/matcha/export_onnx.py
new file mode 100755
index 000000000..fc697643a
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/export_onnx.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This script exports a Matcha-TTS model to ONNX.
+Note that the model outputs fbank. You need to use a vocoder to convert
+it to audio. See also ./export_onnx_hifigan.py
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+from typing import Any, Dict
+
+import onnx
+import torch
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=2000,
+        help="""It specifies the checkpoint to export.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp-new-3",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json containing the fbank mean and std.""",
+    )
+
+    return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+ """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model, num_steps: int = 5): + super().__init__() + self.model = model + self.num_steps = num_steps + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + temperature: torch.Tensor, + length_scale: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + x: (batch_size, num_tokens), torch.int64 + x_lengths: (batch_size,), torch.int64 + temperature: (1,), torch.float32 + length_scale (1,), torch.float32 + Returns: + audio: (batch_size, num_samples) + + """ + mel = self.model.synthesise( + x=x, + x_lengths=x_lengths, + n_timesteps=self.num_steps, + temperature=temperature, + length_scale=length_scale, + )["mel"] + # mel: (batch_size, feat_dim, num_frames) + + return mel + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + for num_steps in [2, 3, 4, 5, 6]: + logging.info(f"num_steps: {num_steps}") + wrapper = ModelWrapper(model, num_steps=num_steps) + wrapper.eval() + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 1000, dtype=torch.int64) + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + temperature = torch.tensor([1.0]) + length_scale = torch.tensor([1.0]) + + opset_version = 14 + filename = f"model-steps-{num_steps}.onnx" + torch.onnx.export( + wrapper, + (x, x_lengths, temperature, length_scale), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "noise_scale", "length_scale"], + output_names=["mel"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_length": {0: "N"}, + "mel": {0: "N", 2: "L"}, + }, + ) + + meta_data = { + "model_type": "matcha-tts", + "language": "Chinese", + "has_espeak": 0, + "n_speakers": 1, + "jieba": 1, + "sample_rate": 22050, + "version": 1, + "pad_id": params.pad_id, + "model_author": "icefall", + "maintainer": "k2-fsa", + "dataset": "baker-zh", + "use_eos_bos": 1, + "dataset_url": "https://www.data-baker.com/open_source.html", + "dataset_comment": "The dataset is for non-commercial use only.", + "num_ode_steps": num_steps, + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py new file mode 120000 index 000000000..d0b8af15b --- /dev/null +++ 
diff --git a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py
new file mode 120000
index 000000000..d0b8af15b
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/matcha/export_onnx_hifigan.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/matcha/generate_lexicon.py b/egs/baker_zh/TTS/matcha/generate_lexicon.py
new file mode 100755
index 000000000..f26f28e91
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/generate_lexicon.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+
+import jieba
+from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict
+from tokenizer import Tokenizer
+
+load_phrases_dict(
+    {
+        "行长": [["hang2"], ["zhang3"]],
+        "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
+    }
+)
+
+
+def main():
+    filename = "lexicon.txt"
+    tokens = "./data/tokens.txt"
+    tokenizer = Tokenizer(tokens)
+
+    word_dict = pinyin_dict.pinyin_dict
+    phrases = phrases_dict.phrases_dict
+
+    i = 0
+    with open(filename, "w", encoding="utf-8") as f:
+        for key in word_dict:
+            if not (0x4E00 <= key <= 0x9FFF):
+                continue
+
+            w = chr(key)
+            tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
+
+            f.write(f"{w} {tokens}\n")
+
+        for key in phrases:
+            tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True)
+            tokens = " ".join(tokens)
+
+            f.write(f"{key} {tokens}\n")
+
+
+if __name__ == "__main__":
+    main()
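Each line of the generated lexicon.txt pairs a word or phrase with its TONE3
pinyin; the phrase registered via load_phrases_dict above, for example, comes
out as "银行行长 yin2 hang2 hang2 zhang3". Below is a hypothetical sketch of
consuming the lexicon at synthesis time, using jieba (declared in the model
metadata) for word segmentation; the actual runtime pipeline may differ:

    import jieba

    # Load the lexicon produced by generate_lexicon.py.
    lexicon = {}
    with open("lexicon.txt", encoding="utf-8") as f:
        for line in f:
            word, *phones = line.split()
            lexicon[word] = phones

    # Segment a sentence into words, then map each word to pinyin tokens.
    phones = []
    for w in jieba.cut("银行行长"):
        phones.extend(lexicon.get(w, []))
    print(phones)  # e.g. ['yin2', 'hang2', 'hang2', 'zhang3']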