zr_jin 2024-11-05 12:12:31 +08:00
parent 06c2993dfc
commit b9e3d24c4d
2 changed files with 44 additions and 71 deletions

View File

@@ -266,10 +266,10 @@ def main():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
-    vocoder = vocoder.to(device)
+    vocoder.to(device)
 
     denoiser = Denoiser(vocoder, mode="zeros")
-    denoiser = denoiser.to(device)
+    denoiser.to(device)
 
     infer_dataset(
         dl=test_dl,
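A note on why dropping the reassignment here is behavior-preserving: unlike torch.Tensor.to(), which returns a new tensor, torch.nn.Module.to() moves the module's parameters and buffers in place and returns self. A minimal sketch of the difference:

import torch

module = torch.nn.Linear(2, 2)
assert module.to("cpu") is module        # Module.to() mutates in place and returns self

t = torch.zeros(2)
assert t.to(torch.float64) is not t      # Tensor.to() returns a new tensor instead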

View File

@@ -2,21 +2,18 @@
 # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
 
 import argparse
-import datetime as dt
 import json
 import logging
 from pathlib import Path
 
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
+from infer import load_vocoder, synthesise, to_waveform
 from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params
 
 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict, setup_logger
 
 
 def get_parser():
@@ -36,7 +33,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-new-3",
+        default="matcha/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -77,60 +74,16 @@ def get_parser():
         help="The filename of the wave to save the generated speech",
     )
 
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
+    )
+
     return parser
 
 
-def load_vocoder(checkpoint_path):
-    checkpoint_path = str(checkpoint_path)
-    if checkpoint_path.endswith("v1"):
-        h = AttributeDict(v1)
-    elif checkpoint_path.endswith("v2"):
-        h = AttributeDict(v2)
-    elif checkpoint_path.endswith("v3"):
-        h = AttributeDict(v3)
-    else:
-        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
-
-    hifigan = HiFiGAN(h).to("cpu")
-    hifigan.load_state_dict(
-        torch.load(checkpoint_path, map_location="cpu")["generator"]
-    )
-    _ = hifigan.eval()
-    hifigan.remove_weight_norm()
-    return hifigan
-
-
-def to_waveform(mel, vocoder, denoiser):
-    audio = vocoder(mel).clamp(-1, 1)
-    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
-    return audio.cpu().squeeze()
-
-
-def process_text(text: str, tokenizer):
-    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
-    x = torch.tensor(x, dtype=torch.long)
-    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
-
-
-def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
-    text_processed = process_text(text, tokenizer)
-    start_t = dt.datetime.now()
-    output = model.synthesise(
-        text_processed["x"],
-        text_processed["x_lengths"],
-        n_timesteps=n_timesteps,
-        temperature=temperature,
-        spks=spks,
-        length_scale=length_scale,
-    )
-    # merge everything to one dict
-    output.update({"start_t": start_t, **text_processed})
-    return output
-
-
 @torch.inference_mode()
 def main():
     parser = get_parser()
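The four helpers removed in this hunk are not deleted from the recipe: per the import change at the top of the file, they now come from the recipe's infer module, so the two scripts share a single implementation. A minimal sketch of how this script consumes them after the refactor, assuming the imported signatures match the calls in main() below (the checkpoint path and input text are made-up placeholders; model, tokenizer, denoiser, and device are assumed to be set up as in main()):

from infer import load_vocoder, synthesise, to_waveform

vocoder = load_vocoder("generator_v1")  # placeholder path; the name is assumed to end in v1/v2/v3
output = synthesise(
    model=model,             # acoustic model built by get_model(params)
    tokenizer=tokenizer,
    n_timesteps=2,           # ODE solver steps
    text="Hello world",      # placeholder input text
    length_scale=1.0,
    temperature=0.667,
    device=device,
)
audio = to_waveform(output["mel"], vocoder, denoiser)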
@@ -139,6 +92,12 @@ def main():
     params.update(vars(args))
 
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
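The added device block is the usual PyTorch pattern of falling back to CPU unless CUDA is available; an equivalent one-liner, shown only for comparison:

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")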
@@ -151,43 +110,57 @@ def main():
     params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
     params.model_args.data_statistics.mel_std = stats["fbank_std"]
 
+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
+
     logging.info(params)
 
     logging.info("About to create model")
     model = get_model(params)
 
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()
 
     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
+    vocoder.to(device)
 
     denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser.to(device)
 
-    # Number of ODE Solver steps
-    n_timesteps = 2
-
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
-
     output = synthesise(
         model=model,
         tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
+        n_timesteps=params.n_timesteps,
         text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
+        length_scale=params.length_scale,
+        temperature=params.temperature,
+        device=device,
     )
 
     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
 
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
+    sf.write(
+        file=params.output_wav,
+        data=output["waveform"],
+        samplerate=params.sampling_rate,
+        subtype="PCM_16",
+    )
 
 
 if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     main()
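On the sf.write change in the last hunk: soundfile.write's positional signature is write(file, data, samplerate, subtype=None, ...), so the old one-liner passed "PCM_16" as the subtype and hard-coded 22050 where the sample rate belongs. The new keyword form wires in the --sampling-rate flag added above and reads unambiguously. A sketch of the equivalent call, with a placeholder path and a waveform assumed to come from to_waveform():

import soundfile as sf

sf.write(
    file="out.wav",        # placeholder output path (params.output_wav in the script)
    data=waveform,         # 1-D audio array returned by to_waveform()
    samplerate=22050,      # now params.sampling_rate instead of a hard-coded literal
    subtype="PCM_16",      # 16-bit PCM output
)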