fix inference

This commit is contained in:
Fangjun Kuang 2024-10-28 19:24:09 +08:00
parent ed569a938a
commit ba4df19224

View File

@ -13,8 +13,6 @@ from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser from matcha.hifigan.denoiser import Denoiser
from tokenizer import Tokenizer from tokenizer import Tokenizer
from matcha.hifigan.models import Generator as HiFiGAN from matcha.hifigan.models import Generator as HiFiGAN
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse
from tqdm.auto import tqdm from tqdm.auto import tqdm
from train import get_model, get_params from train import get_model, get_params
@ -151,8 +149,13 @@ def main():
denoiser = Denoiser(vocoder, mode="zeros") denoiser = Denoiser(vocoder, mode="zeros")
texts = [ texts = [
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", "The Secret Service believed that it was very doubtful that any "
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", "President would ride regularly in a vehicle with a fixed top, even "
"though transparent.",
"Today as always, men fall into two groups: slaves and free men. "
"Whoever does not have two-thirds of his day for himself, is a slave, "
"whatever he may be: a statesman, a businessman, an official, or a "
"scholar.",
] ]
# Number of ODE Solver steps # Number of ODE Solver steps
@ -164,7 +167,7 @@ def main():
# Sampling temperature # Sampling temperature
temperature = 0.667 temperature = 0.667
outputs, rtfs = [], [] rtfs = []
rtfs_w = [] rtfs_w = []
for i, text in enumerate(tqdm(texts)): for i, text in enumerate(tqdm(texts)):
output = synthesise( output = synthesise(
@ -202,7 +205,8 @@ def main():
print(f"Number of ODE steps: {n_timesteps}") print(f"Number of ODE steps: {n_timesteps}")
print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
print( print(
f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}" "Mean RTF Waveform "
f"(incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}"
) )