support all hifigan versions

Fangjun Kuang 2024-10-28 17:51:45 +08:00
parent 748557feba
commit a67d4b9a80
6 changed files with 358 additions and 60 deletions

View File

@@ -1,20 +1,51 @@
#!/usr/bin/env python3
"""
This script exports a Matcha-TTS model to ONNX.
Note that the model outputs fbank. You need to use a vocoder to convert
it to audio. See also ./export_onnx_hifigan.py
"""
import json
import logging
from typing import Any, Dict
import onnx
import torch
from inference import get_parser
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
from onnxruntime.quantization import QuantType, quantize_dynamic
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
def __init__(self, model, num_steps: int = 5):
super().__init__()
self.model = model
self.num_steps = num_steps
def forward(
self,
@@ -30,23 +61,24 @@ class ModelWrapper(torch.nn.Module):
temperature: (1,), torch.float32
length_scale (1,), torch.float32
Returns:
mel: (batch_size, feat_dim, num_frames)
audio: (batch_size, num_samples)
"""
mel = self.model.synthesise(
x=x,
x_lengths=x_lengths,
n_timesteps=3,
n_timesteps=self.num_steps,
temperature=temperature,
length_scale=length_scale,
)["mel"]
# mel: (batch_size, feat_dim, num_frames)
# audio = self.vocoder(mel).clamp(-1, 1).squeeze(1)
return mel
@torch.inference_mode
@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
@@ -72,44 +104,49 @@ def main():
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
wrapper = ModelWrapper(model)
wrapper.eval()
for num_steps in [2, 3, 4, 5, 6]:
logging.info(f"num_steps: {num_steps}")
wrapper = ModelWrapper(model, num_steps=num_steps)
wrapper.eval()
# Use a large value so the rotary position embedding in the text
# encoder has a large initial length
x = torch.ones(1, 2000, dtype=torch.int64)
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
temperature = torch.tensor([1.0])
length_scale = torch.tensor([1.0])
mel = wrapper(x, x_lengths, temperature, length_scale)
print("mel", mel.shape)
# Use a large value so the rotary position embedding in the text
# encoder has a large initial length
x = torch.ones(1, 1000, dtype=torch.int64)
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
temperature = torch.tensor([1.0])
length_scale = torch.tensor([1.0])
opset_version = 14
filename = "model.onnx"
torch.onnx.export(
wrapper,
(x, x_lengths, temperature, length_scale),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "temperature", "length_scale"],
output_names=["mel"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_length": {0: "N"},
"mel": {0: "N", 2: "L"},
},
)
opset_version = 14
filename = f"model-steps-{num_steps}.onnx"
torch.onnx.export(
wrapper,
(x, x_lengths, temperature, length_scale),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "temperature", "length_scale"],
output_names=["mel"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_length": {0: "N"},
"mel": {0: "N", 2: "L"},
},
)
print("Generate int8 quantization models")
filename_int8 = "model.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QInt8,
)
print(f"Saved to {filename} and {filename_int8}")
meta_data = {
"model_type": "matcha-tts",
"language": "English",
"voice": "en-us",
"has_espeak": 1,
"n_speakers": 1,
"sample_rate": 22050,
"version": 1,
"model_author": "icefall",
"maintainer": "k2-fsa",
"dataset": "LJ Speech",
"num_ode_steps": num_steps,
}
add_meta_data(filename=filename, meta_data=meta_data)
print(meta_data)
if __name__ == "__main__":

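A minimal sketch of reading back the metadata written by add_meta_data, assuming onnxruntime is installed and model-steps-2.onnx was produced by the export script above:

import onnxruntime as ort

# Illustrative only; any of the exported model-steps-*.onnx files works here.
sess = ort.InferenceSession("model-steps-2.onnx", providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map  # dict of str -> str
print(meta["model_type"], meta["sample_rate"], meta["num_ode_steps"])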
View File

@@ -0,0 +1,106 @@
#!/usr/bin/env python3
import logging
from typing import Any, Dict
import onnx
import torch
from inference import load_vocoder
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
mel: torch.Tensor,
) -> torch.Tensor:
"""
Args:
mel: (batch_size, feat_dim, num_frames), torch.float32
Returns:
audio: (batch_size, num_samples), torch.float32
"""
audio = self.model(mel).clamp(-1, 1).squeeze(1)
return audio
@torch.inference_mode()
def main():
# Please go to
# https://github.com/csukuangfj/models/tree/master/hifigan
# to download the following files
model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"]
for f in model_filenames:
logging.info(f)
model = load_vocoder(f)
wrapper = ModelWrapper(model)
wrapper.eval()
num_param = sum([p.numel() for p in wrapper.parameters()])
logging.info(f"{f}: Number of parameters: {num_param}")
# Use a large number of frames so the exported model supports long mel inputs
x = torch.ones(1, 80, 100000, dtype=torch.float32)
opset_version = 14
suffix = f.split("_")[-1]
filename = f"hifigan_{suffix}.onnx"
torch.onnx.export(
wrapper,
x,
filename,
opset_version=opset_version,
input_names=["mel"],
output_names=["audio"],
dynamic_axes={
"mel": {0: "N", 2: "L"},
"audio": {0: "N", 1: "L"},
},
)
meta_data = {
"model_type": "hifigan",
"model_filename": f.split("/")[-1],
"sample_rate": 22050,
"version": 1,
"model_author": "jik876",
"maintainer": "k2-fsa",
"dataset": "LJ Speech",
"url1": "https://github.com/jik876/hifi-gan",
"url2": "https://github.com/csukuangfj/models/tree/master/hifigan",
}
add_meta_data(filename=filename, meta_data=meta_data)
print(meta_data)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

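A quick smoke-test sketch for an exported vocoder, assuming hifigan_v2.onnx was produced by export_onnx_hifigan.py above; every mel frame should expand to hop_size = 256 audio samples:

import numpy as np
import onnxruntime as ort

# Illustrative check only; hifigan_v2.onnx is assumed to come from the script above.
sess = ort.InferenceSession("hifigan_v2.onnx", providers=["CPUExecutionProvider"])
mel = np.random.randn(1, 80, 100).astype(np.float32)  # (N, num_mels, num_frames)
(audio,) = sess.run(["audio"], {"mel": mel})
print(audio.shape)  # roughly (1, 100 * 256), since the upsample rates multiply to 256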
View File

@@ -24,5 +24,77 @@ v1 = {
"fmax": 8000,
"fmax_loss": None,
"num_workers": 4,
"dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1,
},
}
# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf
v2 = {
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
"upsample_initial_channel": 128,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_initial_channel": 64,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_loss": None,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1,
},
}
# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1
v3 = {
"resblock": "2",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [8, 8, 4],
"upsample_kernel_sizes": [16, 16, 8],
"upsample_initial_channel": 256,
"resblock_kernel_sizes": [3, 5, 7],
"resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]],
"resblock_initial_channel": 128,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_loss": None,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1,
},
}

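A small sanity-check sketch of a property shared by all three configs: the upsample rates multiply to hop_size, so the generator emits 256 audio samples per mel frame. The import path follows the one used in inference.py:

import math

from matcha.hifigan.config import v1, v2, v3

for name, cfg in (("v1", v1), ("v2", v2), ("v3", v3)):
    total = math.prod(cfg["upsample_rates"])  # 8*8*2*2 for v1/v2, 8*8*4 for v3
    assert total == cfg["hop_size"] == 256, (name, total)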
View File

@@ -9,7 +9,7 @@ import json
import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser
from tokenizer import Tokenizer
from matcha.hifigan.models import Generator as HiFiGAN
@@ -63,7 +63,15 @@ def get_parser():
def load_vocoder(checkpoint_path):
h = AttributeDict(v1)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
elif checkpoint_path.endswith("v2"):
h = AttributeDict(v2)
elif checkpoint_path.endswith("v3"):
h = AttributeDict(v3)
else:
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
@@ -143,13 +151,12 @@ def main():
denoiser = Denoiser(vocoder, mode="zeros")
texts = [
"How are you doing? my friend.",
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
]
# Number of ODE Solver steps
n_timesteps = 3
n_timesteps = 2
# Changes to the speaking rate
length_scale = 1.0
@@ -203,4 +210,6 @@ if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

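A short usage sketch of the new load_vocoder dispatch, assuming the generator_v2 checkpoint from https://github.com/csukuangfj/models/tree/master/hifigan has been downloaded into the current directory:

from inference import load_vocoder

vocoder = load_vocoder("./generator_v2")  # the filename suffix selects the v2 config
# Paths that do not end in v1, v2, or v3 now raise a ValueError.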
View File

@@ -4,9 +4,48 @@ import logging
import onnxruntime as ort
import torch
from tokenizer import Tokenizer
import datetime as dt
from inference import load_vocoder
import soundfile as sf
from inference import load_vocoder
class OnnxHifiGANModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
for i in self.model.get_inputs():
print(i)
print("-----")
for i in self.model.get_outputs():
print(i)
def __call__(self, x: torch.tensor):
assert x.ndim == 3, x.shape
assert x.shape[0] == 1, x.shape
audio = self.model.run(
[self.model.get_outputs()[0].name],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)[0]
return torch.from_numpy(audio)
class OnnxModel:
@@ -16,7 +55,7 @@ class OnnxModel:
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
session_opts.intra_op_num_threads = 2
self.session_opts = session_opts
self.tokenizer = Tokenizer("./data/tokens.txt")
@@ -58,27 +97,63 @@ class OnnxModel:
return torch.from_numpy(mel)
@torch.inference_mode()
@torch.no_grad()
def main():
model = OnnxModel("./model.onnx")
text = "hello, how are you doing?"
text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
model = OnnxModel("./model-steps-6.onnx")
vocoder = OnnxHifiGANModel("./hifigan_v1.onnx")
text = "Today as always, men fall into two groups: slaves and free men."
text += "hello, how are you doing?"
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.int64)
mel = model(x)
print("mel", mel.shape) # (1, 80, 170)
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
audio = vocoder(mel).clamp(-1, 1)
start_t = dt.datetime.now()
mel = model(x)
end_t = dt.datetime.now()
for i in range(3):
audio = vocoder(mel)
start_t2 = dt.datetime.now()
audio = vocoder(mel)
end_t2 = dt.datetime.now()
print("audio", audio.shape) # (1, 1, num_samples)
audio = audio.squeeze()
t = (end_t - start_t).total_seconds()
t2 = (end_t2 - start_t2).total_seconds()
rtf = t * 22050 / audio.shape[-1]
rtf2 = t2 * 22050 / audio.shape[-1]
print("RTF", rtf)
print("RTF", rtf2)
# skip denoiser
sf.write("onnx.wav", audio, 22050, "PCM_16")
sf.write("onnx2.wav", audio, 22050, "PCM_16")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
"""
|HifiGAN |RTF |#Parameters (M)|
|----------|-----|---------------|
|v1 |0.818| 13.926 |
|v2 |0.101| 0.925 |
|v3 |0.118| 1.462 |
|Num steps|Acoustic Model RTF|
|---------|------------------|
| 2 | 0.039 |
| 3 | 0.047 |
| 4 | 0.071 |
| 5 | 0.076 |
| 6 | 0.103 |
"""

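Reading the two tables together: since the acoustic model and the vocoder run one after the other, the end-to-end RTF is roughly the sum of the two stages, e.g. about 0.103 + 0.101 ≈ 0.20 for 6 ODE steps with HiFiGAN v2, versus about 0.039 + 0.818 ≈ 0.86 for 2 steps with v1. These are rough estimates from the numbers above; actual timings depend on the machine.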
View File

@@ -741,8 +741,7 @@ def main():
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()