From a67d4b9a8084364013a8e048fe2d40a184cebab1 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 28 Oct 2024 17:51:45 +0800
Subject: [PATCH] support all hifigan versions

---
 egs/ljspeech/TTS/matcha/export_onnx.py        | 119 ++++++++++++------
 .../TTS/matcha/export_onnx_hifigan.py         | 106 ++++++++++++++++
 egs/ljspeech/TTS/matcha/hifigan/config.py     |  74 ++++++++++-
 egs/ljspeech/TTS/matcha/inference.py          |  17 ++-
 egs/ljspeech/TTS/matcha/onnx_pretrained.py    |  97 ++++++++++++--
 egs/ljspeech/TTS/matcha/train.py              |   5 +-
 6 files changed, 358 insertions(+), 60 deletions(-)
 create mode 100755 egs/ljspeech/TTS/matcha/export_onnx_hifigan.py

diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py
index c56e2da89..cf5069b11 100755
--- a/egs/ljspeech/TTS/matcha/export_onnx.py
+++ b/egs/ljspeech/TTS/matcha/export_onnx.py
@@ -1,20 +1,51 @@
 #!/usr/bin/env python3
+"""
+This script exports a Matcha-TTS model to ONNX.
+Note that the model outputs fbank. You need to use a vocoder to convert
+it to audio. See also ./export_onnx_hifigan.py
+"""
+
 import json
 import logging
+from typing import Any, Dict
 
+import onnx
 import torch
 from inference import get_parser
 from tokenizer import Tokenizer
 from train import get_model, get_params
+
 from icefall.checkpoint import load_checkpoint
-from onnxruntime.quantization import QuantType, quantize_dynamic
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
 
 
 class ModelWrapper(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, num_steps: int = 5):
         super().__init__()
         self.model = model
+        self.num_steps = num_steps
 
     def forward(
         self,
@@ -30,23 +61,24 @@ class ModelWrapper(torch.nn.Module):
           temperature: (1,), torch.float32
          length_scale (1,), torch.float32
         Returns:
          mel: (batch_size, feat_dim, num_frames)
         """
         mel = self.model.synthesise(
             x=x,
             x_lengths=x_lengths,
-            n_timesteps=3,
+            n_timesteps=self.num_steps,
             temperature=temperature,
             length_scale=length_scale,
         )["mel"]
         # mel: (batch_size, feat_dim, num_frames)
+
         return mel
 
 
-@torch.inference_mode
+@torch.inference_mode()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -72,44 +104,49 @@ def main():
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
 
-    wrapper = ModelWrapper(model)
-    wrapper.eval()
+    for num_steps in [2, 3, 4, 5, 6]:
+        logging.info(f"num_steps: {num_steps}")
+        wrapper = ModelWrapper(model, num_steps=num_steps)
+        wrapper.eval()
 
-    # Use a large value so the the rotary position embedding in the text
-    # encoder has a large initial length
-    x = torch.ones(1, 2000, dtype=torch.int64)
-    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
-    temperature = torch.tensor([1.0])
-    length_scale = torch.tensor([1.0])
-    mel = wrapper(x, x_lengths, temperature, length_scale)
-    print("mel", mel.shape)
+        # Use a large value so the rotary position embedding in the text
+        # encoder has a large initial length
+        x = torch.ones(1, 1000, dtype=torch.int64)
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        temperature = torch.tensor([1.0])
+        length_scale = torch.tensor([1.0])
 
-    opset_version = 14
-    filename = "model.onnx"
-    torch.onnx.export(
-        wrapper,
-        (x, x_lengths, temperature, length_scale),
-        filename,
-        opset_version=opset_version,
-        input_names=["x", "x_length", "temperature", "length_scale"],
-        output_names=["mel"],
-        dynamic_axes={
-            "x": {0: "N", 1: "L"},
-            "x_length": {0: "N"},
-            "mel": {0: "N", 2: "L"},
-        },
-    )
+        opset_version = 14
+        filename = f"model-steps-{num_steps}.onnx"
+        torch.onnx.export(
+            wrapper,
+            (x, x_lengths, temperature, length_scale),
+            filename,
+            opset_version=opset_version,
+            input_names=["x", "x_length", "temperature", "length_scale"],
+            output_names=["mel"],
+            dynamic_axes={
+                "x": {0: "N", 1: "L"},
+                "x_length": {0: "N"},
+                "mel": {0: "N", 2: "L"},
+            },
+        )
 
-    print("Generate int8 quantization models")
-
-    filename_int8 = "model.int8.onnx"
-    quantize_dynamic(
-        model_input=filename,
-        model_output=filename_int8,
-        weight_type=QuantType.QInt8,
-    )
-
-    print(f"Saved to {filename} and {filename_int8}")
+        meta_data = {
+            "model_type": "matcha-tts",
+            "language": "English",
+            "voice": "en-us",
+            "has_espeak": 1,
+            "n_speakers": 1,
+            "sample_rate": 22050,
+            "version": 1,
+            "model_author": "icefall",
+            "maintainer": "k2-fsa",
+            "dataset": "LJ Speech",
+            "num_ode_steps": num_steps,
+        }
+        add_meta_data(filename=filename, meta_data=meta_data)
+        print(meta_data)
 
 
 if __name__ == "__main__":
diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py
new file mode 100755
index 000000000..3b2ebf502
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+
+import logging
+from typing import Any, Dict
+
+import onnx
+import torch
+
+from inference import load_vocoder
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        mel: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          mel: (batch_size, feat_dim, num_frames), torch.float32
+        Returns:
+          audio: (batch_size, num_samples), torch.float32
+        """
+        audio = self.model(mel).clamp(-1, 1).squeeze(1)
+        return audio
+
+
+@torch.inference_mode()
+def main():
+    # Please go to
+    # https://github.com/csukuangfj/models/tree/master/hifigan
+    # to download the following files
+    model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"]
+
+    for f in model_filenames:
+        logging.info(f)
+        model = load_vocoder(f)
+        wrapper = ModelWrapper(model)
+        wrapper.eval()
+        num_param = sum([p.numel() for p in wrapper.parameters()])
+        logging.info(f"{f}: Number of parameters: {num_param}")
+
+        # Use a large number of frames for the dummy input;
+        # the frame dimension is exported as a dynamic axis below
+        x = torch.ones(1, 80, 100000, dtype=torch.float32)
+        opset_version = 14
+        suffix = f.split("_")[-1]
+        filename = f"hifigan_{suffix}.onnx"
+        torch.onnx.export(
+            wrapper,
+            x,
+            filename,
+            opset_version=opset_version,
+            input_names=["mel"],
+            output_names=["audio"],
+            dynamic_axes={
+                "mel": {0: "N", 2: "L"},
+                "audio": {0: "N", 1: "L"},
+            },
+        )
+
+        meta_data = {
+            "model_type": "hifigan",
+            "model_filename": f.split("/")[-1],
+            "sample_rate": 22050,
+            "version": 1,
+            "model_author": "jik876",
+            "maintainer": "k2-fsa",
+            "dataset": "LJ Speech",
+            "url1": "https://github.com/jik876/hifi-gan",
+            "url2": "https://github.com/csukuangfj/models/tree/master/hifigan",
+        }
+        add_meta_data(filename=filename, meta_data=meta_data)
+        print(meta_data)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/ljspeech/TTS/matcha/hifigan/config.py b/egs/ljspeech/TTS/matcha/hifigan/config.py
index b3abea9e1..ecba62fd4 100644
--- a/egs/ljspeech/TTS/matcha/hifigan/config.py
+++ b/egs/ljspeech/TTS/matcha/hifigan/config.py
@@ -24,5 +24,77 @@ v1 = {
     "fmax": 8000,
     "fmax_loss": None,
     "num_workers": 4,
-    "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
+}
+
+# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf
+v2 = {
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 16,
+    "learning_rate": 0.0002,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+    "upsample_rates": [8, 8, 2, 2],
+    "upsample_kernel_sizes": [16, 16, 4, 4],
+    "upsample_initial_channel": 128,
+    "resblock_kernel_sizes": [3, 7, 11],
+    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+    "resblock_initial_channel": 64,
+    "segment_size": 8192,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+    "sampling_rate": 22050,
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_loss": None,
+    "num_workers": 4,
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
+}
+
+# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1
+v3 = {
+    "resblock": "2",
+    "num_gpus": 0,
+    "batch_size": 16,
+    "learning_rate": 0.0002,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+    "upsample_rates": [8, 8, 4],
+    "upsample_kernel_sizes": [16, 16, 8],
+    "upsample_initial_channel": 256,
+    "resblock_kernel_sizes": [3, 5, 7],
+    "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]],
+    "resblock_initial_channel": 128,
+    "segment_size": 8192,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+    "sampling_rate": 22050,
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_loss": None,
+    "num_workers": 4,
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
 }
diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py
index 49c9c708a..250c38f20 100755
--- a/egs/ljspeech/TTS/matcha/inference.py
+++ b/egs/ljspeech/TTS/matcha/inference.py
@@ -9,7 +9,7 @@ import json
 import numpy as np
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1
+from matcha.hifigan.config import v1, v2, v3
 from matcha.hifigan.denoiser import Denoiser
 from tokenizer import Tokenizer
 from matcha.hifigan.models import Generator as HiFiGAN
@@ -63,7 +63,15 @@ def get_parser():
 
 
 def load_vocoder(checkpoint_path):
-    h = AttributeDict(v1)
+    if checkpoint_path.endswith("v1"):
+        h = AttributeDict(v1)
+    elif checkpoint_path.endswith("v2"):
+        h = AttributeDict(v2)
+    elif checkpoint_path.endswith("v3"):
+        h = AttributeDict(v3)
+    else:
+        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
+
     hifigan = HiFiGAN(h).to("cpu")
     hifigan.load_state_dict(
         torch.load(checkpoint_path, map_location="cpu")["generator"]
@@ -143,13 +151,12 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")
 
     texts = [
-        "How are you doing? my friend.",
         "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
         "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]
 
     # Number of ODE Solver steps
-    n_timesteps = 3
+    n_timesteps = 2
 
     # Changes to the speaking rate
     length_scale = 1.0
@@ -203,4 +210,6 @@ if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     main()
diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py
index 1a973bcff..24955e881 100755
--- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py
+++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py
@@ -4,9 +4,48 @@ import logging
 import onnxruntime as ort
 import torch
 from tokenizer import Tokenizer
+import datetime as dt
 
-from inference import load_vocoder
 import soundfile as sf
+from inference import load_vocoder
+
+
+class OnnxHifiGANModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 3, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        audio = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+            },
+        )[0]
+
+        return torch.from_numpy(audio)
 
 
 class OnnxModel:
@@ -16,7 +55,7 @@ class OnnxModel:
     ):
         session_opts = ort.SessionOptions()
         session_opts.inter_op_num_threads = 1
-        session_opts.intra_op_num_threads = 1
+        session_opts.intra_op_num_threads = 2
 
         self.session_opts = session_opts
         self.tokenizer = Tokenizer("./data/tokens.txt")
@@ -58,27 +97,63 @@ class OnnxModel:
         return torch.from_numpy(mel)
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def main():
-    model = OnnxModel("./model.onnx")
-    text = "hello, how are you doing?"
-    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
+    model = OnnxModel("./model-steps-6.onnx")
+    vocoder = OnnxHifiGANModel("./hifigan_v1.onnx")
+    text = "Today as always, men fall into two groups: slaves and free men."
+    text += " Hello, how are you doing?"
     x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
     x = torch.tensor(x, dtype=torch.int64)
-    mel = model(x)
-    print("mel", mel.shape)  # (1, 80, 170)
 
-    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
-    audio = vocoder(mel).clamp(-1, 1)
+    start_t = dt.datetime.now()
+    mel = model(x)
+    end_t = dt.datetime.now()
+
+    for i in range(3):
+        audio = vocoder(mel)
+
+    start_t2 = dt.datetime.now()
+    audio = vocoder(mel)
+    end_t2 = dt.datetime.now()
+
     print("audio", audio.shape)  # (1, 1, num_samples)
     audio = audio.squeeze()
 
+    t = (end_t - start_t).total_seconds()
+    t2 = (end_t2 - start_t2).total_seconds()
+    rtf = t * 22050 / audio.shape[-1]
+    rtf2 = t2 * 22050 / audio.shape[-1]
+    print("RTF (acoustic model)", rtf)
+    print("RTF (vocoder)", rtf2)
+
     # skip denoiser
-    sf.write("onnx.wav", audio, 22050, "PCM_16")
+    sf.write("onnx2.wav", audio, 22050, "PCM_16")
 
 
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
+
+"""
+
+| HifiGAN | RTF   | #Parameters (M) |
+|---------|-------|-----------------|
+| v1      | 0.818 | 13.926          |
+| v2      | 0.101 | 0.925           |
+| v3      | 0.118 | 1.462           |
+
+| Num steps | Acoustic Model RTF |
+|-----------|--------------------|
+| 2         | 0.039              |
+| 3         | 0.047              |
+| 4         | 0.071              |
+| 5         | 0.076              |
+| 6         | 0.103              |
+
+"""
diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index bb9307864..747292197 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -741,8 +741,7 @@ def main():
     run(rank=0, world_size=1, args=args)
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     main()
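
Not part of the patch: a minimal sketch of how one might check that add_meta_data() wrote the expected key/value pairs into an exported model. The filename model-steps-2.onnx is an assumption (any file produced by export_onnx.py or export_onnx_hifigan.py would do).

# check_onnx_meta_data.py -- illustrative sketch only, not part of the patch
import onnx

# Load one of the exported models (assumed filename).
model = onnx.load("model-steps-2.onnx")

# metadata_props holds the key/value pairs written by add_meta_data().
meta = {p.key: p.value for p in model.metadata_props}
print(meta)
# Expected keys include model_type, sample_rate and num_ode_steps.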