icefall/egs/ljspeech/TTS/matcha/inference.py
2024-10-29 15:04:04 +08:00

200 lines
5.3 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import datetime as dt
import json
import logging
from pathlib import Path
import soundfile as sf
import torch
from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict
def get_parser():
    """Build the command-line parser for single-utterance Matcha-TTS inference.

    Returns:
      An ``argparse.ArgumentParser``. ``--input-text`` and ``--output-wav`` are
      required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=4000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp-new-3",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--vocoder",
        type=Path,
        default="./generator_v1",
        help="""Path to the HiFi-GAN vocoder checkpoint. The filename must
        end with v1, v2, or v3; the suffix selects the generator config.
        """,
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
        help="Path to the token table used to build the tokenizer",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="""Path to the JSON file with the fbank normalization
        statistics (keys "fbank_mean" and "fbank_std").
        """,
    )

    parser.add_argument(
        "--input-text",
        type=str,
        required=True,
        help="The text to generate speech for",
    )

    parser.add_argument(
        "--output-wav",
        type=str,
        required=True,
        help="The filename of the wave to save the generated speech",
    )

    return parser
def load_vocoder(checkpoint_path):
    """Load a HiFi-GAN generator for CPU inference.

    The generator config (v1, v2 or v3) is chosen from the suffix of
    ``checkpoint_path``. Weights come from the checkpoint's ``"generator"``
    entry. The returned model is in eval mode with weight norm removed.

    Raises:
      ValueError: if the path does not end with ``v1``, ``v2`` or ``v3``.
    """
    path_str = str(checkpoint_path)

    for version in ("v1", "v2", "v3"):
        if path_str.endswith(version):
            break
    else:
        raise ValueError(f"supports only v1, v2, and v3, given {path_str}")

    configs = {"v1": v1, "v2": v2, "v3": v3}
    generator = HiFiGAN(AttributeDict(configs[version])).to("cpu")

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(path_str, map_location="cpu")
    generator.load_state_dict(state["generator"])

    generator.eval()
    generator.remove_weight_norm()
    return generator
def to_waveform(mel, vocoder, denoiser, strength=0.00025):
    """Convert a mel spectrogram into a waveform tensor on the CPU.

    Args:
      mel: Mel spectrogram, passed unchanged to ``vocoder``.
      vocoder: Callable that maps ``mel`` to an audio tensor.
      denoiser: Callable that takes the audio and a ``strength`` keyword.
      strength: Denoising strength forwarded to ``denoiser``. The default
        matches the previously hard-coded value.

    Returns:
      The denoised audio on the CPU with all size-1 dimensions removed. The
      vocoder output is clamped to [-1, 1] before denoising.
    """
    audio = vocoder(mel).clamp(-1, 1)
    # Drop the leading dimension (presumably a batch of size 1) before
    # denoising.
    audio = denoiser(audio.squeeze(0), strength=strength)
    return audio.cpu().squeeze()
def process_text(text: str, tokenizer):
    """Tokenize one utterance into model-ready tensors.

    Returns:
      A dict with the original text (``"x_orig"``), the token ids as a
      ``torch.long`` tensor with SOS/EOS added (``"x"``), and a one-element
      ``torch.long`` tensor holding the token count (``"x_lengths"``).
    """
    token_ids = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    tokens = torch.tensor(token_ids, dtype=torch.long)

    num_tokens = tokens.shape[-1]
    lengths = torch.tensor([num_tokens], dtype=torch.long, device="cpu")

    return dict(x_orig=text, x=tokens, x_lengths=lengths)
def synthesise(
    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
):
    """Run the acoustic model on ``text`` and return its output dict.

    The dict returned by ``model.synthesise`` is extended in place with
    ``"start_t"`` (wall-clock time taken after tokenization, just before the
    model call) and with every entry produced by ``process_text``.
    """
    processed = process_text(text, tokenizer)
    started_at = dt.datetime.now()

    result = model.synthesise(
        processed["x"],
        processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )

    extras = {"start_t": started_at}
    extras.update(processed)
    result.update(extras)
    return result
@torch.inference_mode()
def main():
    """Synthesise ``--input-text`` and write it to ``--output-wav``.

    Loads the Matcha-TTS checkpoint ``<exp-dir>/epoch-<epoch>.pt`` and the
    fbank statistics from ``--cmvn``, runs the acoustic model, converts the
    predicted mel to audio with the HiFi-GAN vocoder and denoiser, and writes
    16-bit PCM audio.

    Raises:
      ValueError: if the model checkpoint or the vocoder file does not exist.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn, encoding="utf-8") as f:
        stats = json.load(f)
    # The same mel normalization statistics go into both the data and the
    # model configs.
    params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
    params.data_args.data_statistics.mel_std = stats["fbank_std"]
    params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
    params.model_args.data_statistics.mel_std = stats["fbank_std"]

    logging.info(params)
    logging.info("About to create model")
    model = get_model(params)

    checkpoint = f"{params.exp_dir}/epoch-{params.epoch}.pt"
    if not Path(checkpoint).is_file():
        raise ValueError(f"{checkpoint} does not exist")
    load_checkpoint(checkpoint, model)
    model.eval()

    if not Path(params.vocoder).is_file():
        raise ValueError(f"{params.vocoder} does not exist")
    vocoder = load_vocoder(params.vocoder)
    denoiser = Denoiser(vocoder, mode="zeros")

    # Number of ODE solver steps
    n_timesteps = 2
    # Changes to the speaking rate
    length_scale = 1.0
    # Sampling temperature
    temperature = 0.667
    # Sample rate written to the output file. NOTE(review): must match the
    # vocoder's output rate; 22050 Hz is LJSpeech's rate. Confirm if the
    # recipe changes.
    sample_rate = 22050

    output = synthesise(
        model=model,
        tokenizer=tokenizer,
        n_timesteps=n_timesteps,
        text=params.input_text,
        length_scale=length_scale,
        temperature=temperature,
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

    sf.write(params.output_wav, output["waveform"], sample_rate, "PCM_16")
if __name__ == "__main__":
    # Log timestamp, level and source location for every message.
    log_format = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=log_format, level=logging.INFO)

    # Restrict PyTorch to one intra-op and one inter-op thread.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    main()