support all hifigan versions

Fangjun Kuang 2024-10-28 17:51:45 +08:00
parent 748557feba
commit a67d4b9a80
6 changed files with 358 additions and 60 deletions

File: export_onnx.py

@@ -1,20 +1,51 @@
 #!/usr/bin/env python3
+"""
+This script exports a Matcha-TTS model to ONNX.
+Note that the model outputs fbank. You need to use a vocoder to convert
+it to audio. See also ./export_onnx_hifigan.py
+"""
+
 import json
 import logging
+from typing import Any, Dict
+
+import onnx
 import torch
 from inference import get_parser
 from tokenizer import Tokenizer
 from train import get_model, get_params

 from icefall.checkpoint import load_checkpoint
-from onnxruntime.quantization import QuantType, quantize_dynamic
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
 class ModelWrapper(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, num_steps: int = 5):
         super().__init__()
         self.model = model
+        self.num_steps = num_steps

     def forward(
         self,
@@ -30,23 +61,24 @@ class ModelWrapper(torch.nn.Module):
           temperature: (1,), torch.float32
           length_scale (1,), torch.float32
         Returns:
-          mel: (batch_size, feat_dim, num_frames)
+          audio: (batch_size, num_samples)
         """
         mel = self.model.synthesise(
             x=x,
             x_lengths=x_lengths,
-            n_timesteps=3,
+            n_timesteps=self.num_steps,
             temperature=temperature,
             length_scale=length_scale,
         )["mel"]
         # mel: (batch_size, feat_dim, num_frames)
+        # audio = self.vocoder(mel).clamp(-1, 1).squeeze(1)

         return mel


-@torch.inference_mode
+@torch.inference_mode()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -72,44 +104,49 @@ def main():
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

-    wrapper = ModelWrapper(model)
-    wrapper.eval()
-
-    # Use a large value so the the rotary position embedding in the text
-    # encoder has a large initial length
-    x = torch.ones(1, 2000, dtype=torch.int64)
-    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
-    temperature = torch.tensor([1.0])
-    length_scale = torch.tensor([1.0])
-    mel = wrapper(x, x_lengths, temperature, length_scale)
-    print("mel", mel.shape)
-
-    opset_version = 14
-    filename = "model.onnx"
-    torch.onnx.export(
-        wrapper,
-        (x, x_lengths, temperature, length_scale),
-        filename,
-        opset_version=opset_version,
-        input_names=["x", "x_length", "temperature", "length_scale"],
-        output_names=["mel"],
-        dynamic_axes={
-            "x": {0: "N", 1: "L"},
-            "x_length": {0: "N"},
-            "mel": {0: "N", 2: "L"},
-        },
-    )
-
-    print("Generate int8 quantization models")
-
-    filename_int8 = "model.int8.onnx"
-    quantize_dynamic(
-        model_input=filename,
-        model_output=filename_int8,
-        weight_type=QuantType.QInt8,
-    )
-
-    print(f"Saved to {filename} and {filename_int8}")
+    for num_steps in [2, 3, 4, 5, 6]:
+        logging.info(f"num_steps: {num_steps}")
+        wrapper = ModelWrapper(model, num_steps=num_steps)
+        wrapper.eval()
+
+        # Use a large value so the rotary position embedding in the text
+        # encoder has a large initial length
+        x = torch.ones(1, 1000, dtype=torch.int64)
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        temperature = torch.tensor([1.0])
+        length_scale = torch.tensor([1.0])
+
+        opset_version = 14
+        filename = f"model-steps-{num_steps}.onnx"
+        torch.onnx.export(
+            wrapper,
+            (x, x_lengths, temperature, length_scale),
+            filename,
+            opset_version=opset_version,
+            input_names=["x", "x_length", "temperature", "length_scale"],
+            output_names=["mel"],
+            dynamic_axes={
+                "x": {0: "N", 1: "L"},
+                "x_length": {0: "N"},
+                "mel": {0: "N", 2: "L"},
+            },
+        )
+
+        meta_data = {
+            "model_type": "matcha-tts",
+            "language": "English",
+            "voice": "en-us",
+            "has_espeak": 1,
+            "n_speakers": 1,
+            "sample_rate": 22050,
+            "version": 1,
+            "model_author": "icefall",
+            "maintainer": "k2-fsa",
+            "dataset": "LJ Speech",
+            "num_ode_steps": num_steps,
+        }
+        add_meta_data(filename=filename, meta_data=meta_data)
+        print(meta_data)

 if __name__ == "__main__":
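
Note: the key-value pairs written by add_meta_data are stored in the model's
metadata_props and can be read back at load time through onnxruntime. A
minimal sketch, assuming one of the exported files, e.g. model-steps-2.onnx,
is in the current directory:

import onnxruntime as ort

session = ort.InferenceSession(
    "model-steps-2.onnx", providers=["CPUExecutionProvider"]
)
# custom_metadata_map holds the pairs written above; values come back as strings
meta = session.get_modelmeta().custom_metadata_map
print(meta["model_type"], meta["num_ode_steps"], meta["sample_rate"])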

File: export_onnx_hifigan.py (new file)

@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+
+import logging
+from typing import Any, Dict
+
+import onnx
+import torch
+from inference import load_vocoder
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        mel: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          mel: (batch_size, feat_dim, num_frames), torch.float32
+        Returns:
+          audio: (batch_size, num_samples), torch.float32
+        """
+        audio = self.model(mel).clamp(-1, 1).squeeze(1)
+        return audio
+
+
+@torch.inference_mode()
+def main():
+    # Please go to
+    # https://github.com/csukuangfj/models/tree/master/hifigan
+    # to download the following files
+    model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"]
+
+    for f in model_filenames:
+        logging.info(f)
+        model = load_vocoder(f)
+        wrapper = ModelWrapper(model)
+        wrapper.eval()
+        num_param = sum([p.numel() for p in wrapper.parameters()])
+        logging.info(f"{f}: Number of parameters: {num_param}")
+
+        # Use a long input so that the exported model is traced
+        # with a large number of mel frames
+        x = torch.ones(1, 80, 100000, dtype=torch.float32)
+        opset_version = 14
+        suffix = f.split("_")[-1]
+        filename = f"hifigan_{suffix}.onnx"
+        torch.onnx.export(
+            wrapper,
+            x,
+            filename,
+            opset_version=opset_version,
+            input_names=["mel"],
+            output_names=["audio"],
+            dynamic_axes={
+                "mel": {0: "N", 2: "L"},
+                "audio": {0: "N", 1: "L"},
+            },
+        )
+
+        meta_data = {
+            "model_type": "hifigan",
+            "model_filename": f.split("/")[-1],
+            "sample_rate": 22050,
+            "version": 1,
+            "model_author": "jik876",
+            "maintainer": "k2-fsa",
+            "dataset": "LJ Speech",
+            "url1": "https://github.com/jik876/hifi-gan",
+            "url2": "https://github.com/csukuangfj/models/tree/master/hifigan",
+        }
+        add_meta_data(filename=filename, meta_data=meta_data)
+        print(meta_data)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
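
Note: a quick way to sanity-check an exported vocoder is to feed it a random
mel and verify the output length, since the generator upsamples each mel frame
to exactly hop_size = 256 samples. A sketch, assuming hifigan_v1.onnx was
produced by the script above:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("hifigan_v1.onnx", providers=["CPUExecutionProvider"])
mel = np.random.randn(1, 80, 100).astype(np.float32)  # 100 frames of 80-dim mel
(audio,) = session.run(["audio"], {"mel": mel})
assert audio.shape == (1, 100 * 256), audio.shape  # one frame -> 256 samples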

File: matcha/hifigan/config.py

@@ -24,5 +24,77 @@ v1 = {
     "fmax": 8000,
     "fmax_loss": None,
     "num_workers": 4,
-    "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
+}
+
+# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf
+v2 = {
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 16,
+    "learning_rate": 0.0002,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+    "upsample_rates": [8, 8, 2, 2],
+    "upsample_kernel_sizes": [16, 16, 4, 4],
+    "upsample_initial_channel": 128,
+    "resblock_kernel_sizes": [3, 7, 11],
+    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+    "resblock_initial_channel": 64,
+    "segment_size": 8192,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+    "sampling_rate": 22050,
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_loss": None,
+    "num_workers": 4,
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
+}
+
+# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1
+v3 = {
+    "resblock": "2",
+    "num_gpus": 0,
+    "batch_size": 16,
+    "learning_rate": 0.0002,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+    "upsample_rates": [8, 8, 4],
+    "upsample_kernel_sizes": [16, 16, 8],
+    "upsample_initial_channel": 256,
+    "resblock_kernel_sizes": [3, 5, 7],
+    "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]],
+    "resblock_initial_channel": 128,
+    "segment_size": 8192,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+    "sampling_rate": 22050,
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_loss": None,
+    "num_workers": 4,
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
 }
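
Note: all three configs share hop_size = 256, and the product of
upsample_rates must equal hop_size because the generator turns each mel frame
into hop_size audio samples: v2 uses 8 * 8 * 2 * 2 = 256 and v3 uses
8 * 8 * 4 = 256. A small consistency check (sketch; assumes v1 follows the
same rule):

import math

from matcha.hifigan.config import v1, v2, v3

for name, cfg in [("v1", v1), ("v2", v2), ("v3", v3)]:
    # e.g. for v3: math.prod([8, 8, 4]) == 256 == cfg["hop_size"]
    assert math.prod(cfg["upsample_rates"]) == cfg["hop_size"], name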

File: inference.py

@@ -9,7 +9,7 @@ import json
 import numpy as np
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1
+from matcha.hifigan.config import v1, v2, v3
 from matcha.hifigan.denoiser import Denoiser
 from tokenizer import Tokenizer
 from matcha.hifigan.models import Generator as HiFiGAN

@@ -63,7 +63,15 @@ def get_parser():
 def load_vocoder(checkpoint_path):
-    h = AttributeDict(v1)
+    if checkpoint_path.endswith("v1"):
+        h = AttributeDict(v1)
+    elif checkpoint_path.endswith("v2"):
+        h = AttributeDict(v2)
+    elif checkpoint_path.endswith("v3"):
+        h = AttributeDict(v3)
+    else:
+        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
+
     hifigan = HiFiGAN(h).to("cpu")
     hifigan.load_state_dict(
         torch.load(checkpoint_path, map_location="cpu")["generator"]

@@ -143,13 +151,12 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")

     texts = [
-        "How are you doing? my friend.",
         "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
         "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]

     # Number of ODE Solver steps
-    n_timesteps = 3
+    n_timesteps = 2

     # Changes to the speaking rate
     length_scale = 1.0

@@ -203,4 +210,6 @@ if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     main()
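
Note: load_vocoder now picks the matching config from the checkpoint filename,
so the checkpoint must end in v1, v2, or v3. A usage sketch mirroring
export_onnx_hifigan.py:

from inference import load_vocoder

for name in ["./generator_v1", "./generator_v2", "./generator_v3"]:
    vocoder = load_vocoder(name)
    num_param = sum(p.numel() for p in vocoder.parameters())
    print(name, f"{num_param / 1e6:.3f} M parameters")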

File: onnx_pretrained.py

@@ -4,9 +4,48 @@ import logging
 import onnxruntime as ort
 import torch
 from tokenizer import Tokenizer
+import datetime as dt
-from inference import load_vocoder
 import soundfile as sf
+from inference import load_vocoder
+
+
+class OnnxHifiGANModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.tensor):
+        assert x.ndim == 3, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        audio = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+            },
+        )[0]
+        return torch.from_numpy(audio)


 class OnnxModel:

@@ -16,7 +55,7 @@ class OnnxModel:
     ):
         session_opts = ort.SessionOptions()
         session_opts.inter_op_num_threads = 1
-        session_opts.intra_op_num_threads = 1
+        session_opts.intra_op_num_threads = 2

         self.session_opts = session_opts
         self.tokenizer = Tokenizer("./data/tokens.txt")

@@ -58,27 +97,63 @@ class OnnxModel:
         return torch.from_numpy(mel)


-@torch.inference_mode()
+@torch.no_grad()
 def main():
-    model = OnnxModel("./model.onnx")
-    text = "hello, how are you doing?"
-    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
+    model = OnnxModel("./model-steps-6.onnx")
+    vocoder = OnnxHifiGANModel("./hifigan_v1.onnx")
+    text = "Today as always, men fall into two groups: slaves and free men."
+    text += "hello, how are you doing?"
     x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
     x = torch.tensor(x, dtype=torch.int64)

-    mel = model(x)
-    print("mel", mel.shape)  # (1, 80, 170)
-
-    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
-    audio = vocoder(mel).clamp(-1, 1)
+    start_t = dt.datetime.now()
+    mel = model(x)
+    end_t = dt.datetime.now()
+
+    for i in range(3):
+        audio = vocoder(mel)
+
+    start_t2 = dt.datetime.now()
+    audio = vocoder(mel)
+    end_t2 = dt.datetime.now()

     print("audio", audio.shape)  # (1, 1, num_samples)
     audio = audio.squeeze()

+    t = (end_t - start_t).total_seconds()
+    t2 = (end_t2 - start_t2).total_seconds()
+    rtf = t * 22050 / audio.shape[-1]
+    rtf2 = t2 * 22050 / audio.shape[-1]
+    print("RTF", rtf)
+    print("RTF", rtf2)
+
     # skip denoiser
-    sf.write("onnx.wav", audio, 22050, "PCM_16")
+    sf.write("onnx2.wav", audio, 22050, "PCM_16")


 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
+
+"""
+|HifiGAN  |RTF  |#Parameters (M)|
+|---------|-----|---------------|
+|v1       |0.818|13.926         |
+|v2       |0.101| 0.925         |
+|v3       |0.118| 1.462         |
+
+|Num steps|Acoustic Model RTF|
+|---------|------------------|
+| 2       |0.039             |
+| 3       |0.047             |
+| 4       |0.071             |
+| 5       |0.076             |
+| 6       |0.103             |
+"""

File: train.py

@@ -741,8 +741,7 @@ def main():
     run(rank=0, world_size=1, args=args)

-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     main()