mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
support all hifigan versions
This commit is contained in:
parent
748557feba
commit
a67d4b9a80
@ -1,20 +1,51 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a Matcha-TTS model to ONNX.
|
||||||
|
Note that the model outputs fbank. You need to use a vocoder to convert
|
||||||
|
it to audio. See also ./export_onnx_hifigan.py
|
||||||
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
from inference import get_parser
|
from inference import get_parser
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from train import get_model, get_params
|
from train import get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
|
||||||
|
while len(model.metadata_props):
|
||||||
|
model.metadata_props.pop()
|
||||||
|
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = str(value)
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
class ModelWrapper(torch.nn.Module):
|
class ModelWrapper(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model, num_steps: int = 5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.num_steps = num_steps
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -30,23 +61,24 @@ class ModelWrapper(torch.nn.Module):
|
|||||||
temperature: (1,), torch.float32
|
temperature: (1,), torch.float32
|
||||||
length_scale (1,), torch.float32
|
length_scale (1,), torch.float32
|
||||||
Returns:
|
Returns:
|
||||||
mel: (batch_size, feat_dim, num_frames)
|
audio: (batch_size, num_samples)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
mel = self.model.synthesise(
|
mel = self.model.synthesise(
|
||||||
x=x,
|
x=x,
|
||||||
x_lengths=x_lengths,
|
x_lengths=x_lengths,
|
||||||
n_timesteps=3,
|
n_timesteps=self.num_steps,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
length_scale=length_scale,
|
length_scale=length_scale,
|
||||||
)["mel"]
|
)["mel"]
|
||||||
|
|
||||||
# mel: (batch_size, feat_dim, num_frames)
|
# mel: (batch_size, feat_dim, num_frames)
|
||||||
|
|
||||||
|
# audio = self.vocoder(mel).clamp(-1, 1).squeeze(1)
|
||||||
|
|
||||||
return mel
|
return mel
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -72,20 +104,20 @@ def main():
|
|||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
|
||||||
wrapper = ModelWrapper(model)
|
for num_steps in [2, 3, 4, 5, 6]:
|
||||||
|
logging.info(f"num_steps: {num_steps}")
|
||||||
|
wrapper = ModelWrapper(model, num_steps=num_steps)
|
||||||
wrapper.eval()
|
wrapper.eval()
|
||||||
|
|
||||||
# Use a large value so the the rotary position embedding in the text
|
# Use a large value so the rotary position embedding in the text
|
||||||
# encoder has a large initial length
|
# encoder has a large initial length
|
||||||
x = torch.ones(1, 2000, dtype=torch.int64)
|
x = torch.ones(1, 1000, dtype=torch.int64)
|
||||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||||
temperature = torch.tensor([1.0])
|
temperature = torch.tensor([1.0])
|
||||||
length_scale = torch.tensor([1.0])
|
length_scale = torch.tensor([1.0])
|
||||||
mel = wrapper(x, x_lengths, temperature, length_scale)
|
|
||||||
print("mel", mel.shape)
|
|
||||||
|
|
||||||
opset_version = 14
|
opset_version = 14
|
||||||
filename = "model.onnx"
|
filename = f"model-steps-{num_steps}.onnx"
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
wrapper,
|
wrapper,
|
||||||
(x, x_lengths, temperature, length_scale),
|
(x, x_lengths, temperature, length_scale),
|
||||||
@ -100,16 +132,21 @@ def main():
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Generate int8 quantization models")
|
meta_data = {
|
||||||
|
"model_type": "matcha-tts",
|
||||||
filename_int8 = "model.int8.onnx"
|
"language": "English",
|
||||||
quantize_dynamic(
|
"voice": "en-us",
|
||||||
model_input=filename,
|
"has_espeak": 1,
|
||||||
model_output=filename_int8,
|
"n_speakers": 1,
|
||||||
weight_type=QuantType.QInt8,
|
"sample_rate": 22050,
|
||||||
)
|
"version": 1,
|
||||||
|
"model_author": "icefall",
|
||||||
print(f"Saved to {filename} and {filename_int8}")
|
"maintainer": "k2-fsa",
|
||||||
|
"dataset": "LJ Speech",
|
||||||
|
"num_ode_steps": num_steps,
|
||||||
|
}
|
||||||
|
add_meta_data(filename=filename, meta_data=meta_data)
|
||||||
|
print(meta_data)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
106
egs/ljspeech/TTS/matcha/export_onnx_hifigan.py
Executable file
106
egs/ljspeech/TTS/matcha/export_onnx_hifigan.py
Executable file
@ -0,0 +1,106 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from inference import load_vocoder
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
|
||||||
|
while len(model.metadata_props):
|
||||||
|
model.metadata_props.pop()
|
||||||
|
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = str(value)
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWrapper(torch.nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
mel: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args: :
|
||||||
|
mel: (batch_size, feat_dim, num_frames), torch.float32
|
||||||
|
Returns:
|
||||||
|
audio: (batch_size, num_samples), torch.float32
|
||||||
|
"""
|
||||||
|
audio = self.model(mel).clamp(-1, 1).squeeze(1)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
# Please go to
|
||||||
|
# https://github.com/csukuangfj/models/tree/master/hifigan
|
||||||
|
# to download the following files
|
||||||
|
model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"]
|
||||||
|
|
||||||
|
for f in model_filenames:
|
||||||
|
logging.info(f)
|
||||||
|
model = load_vocoder(f)
|
||||||
|
wrapper = ModelWrapper(model)
|
||||||
|
wrapper.eval()
|
||||||
|
num_param = sum([p.numel() for p in wrapper.parameters()])
|
||||||
|
logging.info(f"{f}: Number of parameters: {num_param}")
|
||||||
|
|
||||||
|
# Use a large value so the rotary position embedding in the text
|
||||||
|
# encoder has a large initial length
|
||||||
|
x = torch.ones(1, 80, 100000, dtype=torch.float32)
|
||||||
|
opset_version = 14
|
||||||
|
suffix = f.split("_")[-1]
|
||||||
|
filename = f"hifigan_{suffix}.onnx"
|
||||||
|
torch.onnx.export(
|
||||||
|
wrapper,
|
||||||
|
x,
|
||||||
|
filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["mel"],
|
||||||
|
output_names=["audio"],
|
||||||
|
dynamic_axes={
|
||||||
|
"mel": {0: "N", 2: "L"},
|
||||||
|
"audio": {0: "N", 1: "L"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "hifigan",
|
||||||
|
"model_filename": f.split("/")[-1],
|
||||||
|
"sample_rate": 22050,
|
||||||
|
"version": 1,
|
||||||
|
"model_author": "jik876",
|
||||||
|
"maintainer": "k2-fsa",
|
||||||
|
"dataset": "LJ Speech",
|
||||||
|
"url1": "https://github.com/jik876/hifi-gan",
|
||||||
|
"url2": "https://github.com/csukuangfj/models/tree/master/hifigan",
|
||||||
|
}
|
||||||
|
add_meta_data(filename=filename, meta_data=meta_data)
|
||||||
|
print(meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -24,5 +24,77 @@ v1 = {
|
|||||||
"fmax": 8000,
|
"fmax": 8000,
|
||||||
"fmax_loss": None,
|
"fmax_loss": None,
|
||||||
"num_workers": 4,
|
"num_workers": 4,
|
||||||
"dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf
|
||||||
|
v2 = {
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 16,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.999,
|
||||||
|
"seed": 1234,
|
||||||
|
"upsample_rates": [8, 8, 2, 2],
|
||||||
|
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||||
|
"upsample_initial_channel": 128,
|
||||||
|
"resblock_kernel_sizes": [3, 7, 11],
|
||||||
|
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
"resblock_initial_channel": 64,
|
||||||
|
"segment_size": 8192,
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 8000,
|
||||||
|
"fmax_loss": None,
|
||||||
|
"num_workers": 4,
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1
|
||||||
|
v3 = {
|
||||||
|
"resblock": "2",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 16,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.999,
|
||||||
|
"seed": 1234,
|
||||||
|
"upsample_rates": [8, 8, 4],
|
||||||
|
"upsample_kernel_sizes": [16, 16, 8],
|
||||||
|
"upsample_initial_channel": 256,
|
||||||
|
"resblock_kernel_sizes": [3, 5, 7],
|
||||||
|
"resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]],
|
||||||
|
"resblock_initial_channel": 128,
|
||||||
|
"segment_size": 8192,
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 8000,
|
||||||
|
"fmax_loss": None,
|
||||||
|
"num_workers": 4,
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ import json
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
from matcha.hifigan.config import v1
|
from matcha.hifigan.config import v1, v2, v3
|
||||||
from matcha.hifigan.denoiser import Denoiser
|
from matcha.hifigan.denoiser import Denoiser
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from matcha.hifigan.models import Generator as HiFiGAN
|
from matcha.hifigan.models import Generator as HiFiGAN
|
||||||
@ -63,7 +63,15 @@ def get_parser():
|
|||||||
|
|
||||||
|
|
||||||
def load_vocoder(checkpoint_path):
|
def load_vocoder(checkpoint_path):
|
||||||
|
if checkpoint_path.endswith("v1"):
|
||||||
h = AttributeDict(v1)
|
h = AttributeDict(v1)
|
||||||
|
elif checkpoint_path.endswith("v2"):
|
||||||
|
h = AttributeDict(v2)
|
||||||
|
elif checkpoint_path.endswith("v3"):
|
||||||
|
h = AttributeDict(v3)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
|
||||||
|
|
||||||
hifigan = HiFiGAN(h).to("cpu")
|
hifigan = HiFiGAN(h).to("cpu")
|
||||||
hifigan.load_state_dict(
|
hifigan.load_state_dict(
|
||||||
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||||
@ -143,13 +151,12 @@ def main():
|
|||||||
denoiser = Denoiser(vocoder, mode="zeros")
|
denoiser = Denoiser(vocoder, mode="zeros")
|
||||||
|
|
||||||
texts = [
|
texts = [
|
||||||
"How are you doing? my friend.",
|
|
||||||
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
|
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
|
||||||
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
|
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Number of ODE Solver steps
|
# Number of ODE Solver steps
|
||||||
n_timesteps = 3
|
n_timesteps = 2
|
||||||
|
|
||||||
# Changes to the speaking rate
|
# Changes to the speaking rate
|
||||||
length_scale = 1.0
|
length_scale = 1.0
|
||||||
@ -203,4 +210,6 @@ if __name__ == "__main__":
|
|||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
main()
|
main()
|
||||||
|
@ -4,9 +4,48 @@ import logging
|
|||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
from inference import load_vocoder
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
from inference import load_vocoder
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxHifiGANModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
self.model = ort.InferenceSession(
|
||||||
|
filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in self.model.get_inputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
print("-----")
|
||||||
|
|
||||||
|
for i in self.model.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
def __call__(self, x: torch.tensor):
|
||||||
|
assert x.ndim == 3, x.shape
|
||||||
|
assert x.shape[0] == 1, x.shape
|
||||||
|
|
||||||
|
audio = self.model.run(
|
||||||
|
[self.model.get_outputs()[0].name],
|
||||||
|
{
|
||||||
|
self.model.get_inputs()[0].name: x.numpy(),
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return torch.from_numpy(audio)
|
||||||
|
|
||||||
|
|
||||||
class OnnxModel:
|
class OnnxModel:
|
||||||
@ -16,7 +55,7 @@ class OnnxModel:
|
|||||||
):
|
):
|
||||||
session_opts = ort.SessionOptions()
|
session_opts = ort.SessionOptions()
|
||||||
session_opts.inter_op_num_threads = 1
|
session_opts.inter_op_num_threads = 1
|
||||||
session_opts.intra_op_num_threads = 1
|
session_opts.intra_op_num_threads = 2
|
||||||
|
|
||||||
self.session_opts = session_opts
|
self.session_opts = session_opts
|
||||||
self.tokenizer = Tokenizer("./data/tokens.txt")
|
self.tokenizer = Tokenizer("./data/tokens.txt")
|
||||||
@ -58,27 +97,63 @@ class OnnxModel:
|
|||||||
return torch.from_numpy(mel)
|
return torch.from_numpy(mel)
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
model = OnnxModel("./model.onnx")
|
model = OnnxModel("./model-steps-6.onnx")
|
||||||
text = "hello, how are you doing?"
|
vocoder = OnnxHifiGANModel("./hifigan_v1.onnx")
|
||||||
text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
|
text = "Today as always, men fall into two groups: slaves and free men."
|
||||||
|
text += "hello, how are you doing?"
|
||||||
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
||||||
x = torch.tensor(x, dtype=torch.int64)
|
x = torch.tensor(x, dtype=torch.int64)
|
||||||
mel = model(x)
|
|
||||||
print("mel", mel.shape) # (1, 80, 170)
|
|
||||||
|
|
||||||
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
|
start_t = dt.datetime.now()
|
||||||
audio = vocoder(mel).clamp(-1, 1)
|
mel = model(x)
|
||||||
|
end_t = dt.datetime.now()
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
audio = vocoder(mel)
|
||||||
|
|
||||||
|
start_t2 = dt.datetime.now()
|
||||||
|
audio = vocoder(mel)
|
||||||
|
end_t2 = dt.datetime.now()
|
||||||
|
|
||||||
print("audio", audio.shape) # (1, 1, num_samples)
|
print("audio", audio.shape) # (1, 1, num_samples)
|
||||||
audio = audio.squeeze()
|
audio = audio.squeeze()
|
||||||
|
|
||||||
|
t = (end_t - start_t).total_seconds()
|
||||||
|
t2 = (end_t2 - start_t2).total_seconds()
|
||||||
|
rtf = t * 22050 / audio.shape[-1]
|
||||||
|
rtf2 = t2 * 22050 / audio.shape[-1]
|
||||||
|
print("RTF", rtf)
|
||||||
|
print("RTF", rtf2)
|
||||||
|
|
||||||
# skip denoiser
|
# skip denoiser
|
||||||
sf.write("onnx.wav", audio, 22050, "PCM_16")
|
sf.write("onnx2.wav", audio, 22050, "PCM_16")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|HifiGAN |RTF |#Parameters (M)|
|
||||||
|
|----------|-----|---------------|
|
||||||
|
|v1 |0.818| 13.926 |
|
||||||
|
|v2 |0.101| 0.925 |
|
||||||
|
|v3 |0.118| 1.462 |
|
||||||
|
|
||||||
|
|Num steps|Acoustic Model RTF|
|
||||||
|
|---------|------------------|
|
||||||
|
| 2 | 0.039 |
|
||||||
|
| 3 | 0.047 |
|
||||||
|
| 4 | 0.071 |
|
||||||
|
| 5 | 0.076 |
|
||||||
|
| 6 | 0.103 |
|
||||||
|
|
||||||
|
"""
|
||||||
|
@ -741,8 +741,7 @@ def main():
|
|||||||
run(rank=0, world_size=1, args=args)
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user