mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
216 lines
6.5 KiB
Python
Executable File
216 lines
6.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import datetime as dt
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import json
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import torch
|
|
from matcha.hifigan.config import v1, v2, v3
|
|
from matcha.hifigan.denoiser import Denoiser
|
|
from tokenizer import Tokenizer
|
|
from matcha.hifigan.models import Generator as HiFiGAN
|
|
from matcha.text import sequence_to_text, text_to_sequence
|
|
from matcha.utils.utils import intersperse
|
|
from tqdm.auto import tqdm
|
|
from train import get_model, get_params
|
|
|
|
from icefall.checkpoint import load_checkpoint
|
|
from icefall.utils import AttributeDict
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for this inference script.

    Returns:
      An ``argparse.ArgumentParser`` that understands ``--epoch``,
      ``--exp-dir``, ``--tokens`` and ``--cmvn``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=2810,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp-new-3",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    # A help string is required for ArgumentDefaultsHelpFormatter to show
    # the default value of an option.
    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
        help="""Path to the tokens file that is passed to the tokenizer.""",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="""Path to the JSON file providing fbank_mean and fbank_std.""",
    )

    return parser
|
|
|
|
|
|
def load_vocoder(checkpoint_path):
    """Load a HiFi-GAN generator from ``checkpoint_path`` onto the CPU.

    The config variant (v1, v2 or v3) is chosen from the suffix of the path.

    Args:
      checkpoint_path:
        Path string of the checkpoint. It must end with "v1", "v2" or "v3",
        and the loaded object must contain a "generator" entry.
    Returns:
      The generator, switched to eval mode and with weight norm removed.
    Raises:
      ValueError: If the path ends with none of the supported suffixes.
    """
    variant = next(
        (tag for tag in ("v1", "v2", "v3") if checkpoint_path.endswith(tag)),
        None,
    )
    if variant is None:
        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")

    # Look the config up only after the suffix has been validated.
    config = {"v1": v1, "v2": v2, "v3": v3}[variant]

    generator = HiFiGAN(AttributeDict(config)).to("cpu")
    state = torch.load(checkpoint_path, map_location="cpu")
    generator.load_state_dict(state["generator"])
    generator.eval()
    generator.remove_weight_norm()
    return generator
|
|
|
|
|
|
def to_waveform(mel, vocoder, denoiser):
    """Turn a mel-spectrogram into a denoised waveform on the CPU.

    Args:
      mel:
        The input handed to ``vocoder``.
      vocoder:
        Callable that maps ``mel`` to an audio tensor.
      denoiser:
        Callable that takes the audio (with its first dim squeezed) and a
        ``strength`` keyword argument and returns a tensor.
    Returns:
      The denoised audio on the CPU with all singleton dimensions removed.
    """
    # Keep the vocoder output inside the valid [-1, 1] sample range.
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025)
    # A single transfer + squeeze is enough: both operations are idempotent.
    return audio.cpu().squeeze()
|
|
|
|
|
|
def save_to_folder(filename: str, output: dict, folder: str):
    """Write the mel-spectrogram and waveform of ``output`` into ``folder``.

    The folder (including missing parents) is created when needed.

    Args:
      filename:
        Base name, without extension, used for both files.
      output:
        Mapping with a "mel" tensor and a "waveform" entry.
      folder:
        Destination directory.
    """
    out_dir = Path(folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    # np.save appends ".npy" itself when the target has no such suffix.
    mel_path = out_dir / f"{filename}"
    mel = output["mel"].cpu().numpy()
    np.save(mel_path, mel)

    # The waveform is written as 24-bit PCM at 22050 Hz.
    wav_path = out_dir / f"{filename}.wav"
    sf.write(wav_path, output["waveform"], 22050, "PCM_24")
|
|
|
|
|
|
def process_text(text: str, tokenizer):
    """Convert ``text`` into the token tensors fed to the model.

    Args:
      text:
        The sentence to tokenize.
      tokenizer:
        Object with a ``texts_to_token_ids(texts, add_sos, add_eos)`` method.
    Returns:
      A dict with:
        - "x_orig": the unmodified input text
        - "x": the token ids as a ``torch.long`` tensor
        - "x_lengths": a one-element ``torch.long`` CPU tensor holding the
          size of the last dimension of "x"
    """
    token_ids = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    tokens = torch.tensor(token_ids, dtype=torch.long)
    num_tokens = tokens.shape[-1]
    lengths = torch.tensor([num_tokens], dtype=torch.long, device="cpu")
    return dict(x_orig=text, x=tokens, x_lengths=lengths)
|
|
|
|
|
|
def synthesise(
    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
):
    """Run ``model.synthesise`` on ``text`` and return its enriched output.

    Args:
      model:
        Object exposing a ``synthesise`` method.
      tokenizer:
        Tokenizer forwarded to ``process_text``.
      n_timesteps:
        Forwarded to ``model.synthesise``.
      text:
        The sentence to synthesise.
      length_scale:
        Forwarded to ``model.synthesise``.
      temperature:
        Forwarded to ``model.synthesise``.
      spks:
        Forwarded to ``model.synthesise``; defaults to None.
    Returns:
      The mapping returned by ``model.synthesise`` extended with "start_t"
      (the time taken right before the model call) and the entries produced
      by ``process_text`` ("x_orig", "x", "x_lengths").
    """
    inputs = process_text(text, tokenizer)

    # Taken after tokenization, so the timestamp marks the start of the
    # model call only.
    started_at = dt.datetime.now()
    result = model.synthesise(
        inputs["x"],
        inputs["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    print("output.shape", list(result.keys()), result["mel"].shape)

    # Fold the start time and the processed text into the model output.
    result.update(start_t=started_at, **inputs)
    return result
|
|
|
|
|
|
@torch.inference_mode()
def main():
    """Synthesise a fixed list of sentences and report real-time factors.

    Steps:
      1. Parse the command line and merge it into the training params.
      2. Configure vocabulary size and mel statistics from ``--tokens`` and
         ``--cmvn``.
      3. Load the checkpoint ``epoch-{epoch}.pt`` from ``--exp-dir``.
      4. For every sentence: synthesise a mel, vocode it, print timing
         information and save the result to ``./my-output-{epoch}``.
    """
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()

    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)

    # The same statistics are written to both the data and the model args.
    params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
    params.data_args.data_statistics.mel_std = stats["fbank_std"]

    params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
    params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.eval()

    # NOTE(review): this vocoder path is machine-specific; it has to exist on
    # the host running the script.
    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
    denoiser = Denoiser(vocoder, mode="zeros")

    texts = [
        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
    ]

    # Number of ODE Solver steps
    n_timesteps = 2

    # Changes to the speaking rate
    length_scale = 1.0

    # Sampling temperature
    temperature = 0.667

    # Samples per second used for the RTF computation; it matches the rate
    # that save_to_folder() writes the wav files with.
    sample_rate = 22050

    rtfs = []
    rtfs_w = []
    for i, text in enumerate(tqdm(texts)):
        output = synthesise(
            model=model,
            tokenizer=tokenizer,
            n_timesteps=n_timesteps,
            text=text,
            length_scale=length_scale,
            temperature=temperature,
        )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        # Compute Real Time Factor (RTF) with HiFi-GAN
        t = (dt.datetime.now() - output["start_t"]).total_seconds()
        rtf_w = t * sample_rate / (output["waveform"].shape[-1])

        # Pretty print
        print(f"{'*' * 53}")
        print(f"Input text - {i}")
        print(f"{'-' * 53}")
        print(output["x_orig"])
        print(f"{'*' * 53}")
        print(f"Phonetised text - {i}")
        print(f"{'-' * 53}")
        print(output["x"])
        print(f"{'*' * 53}")
        # "rtf" is not added by synthesise() in this file, so it presumably
        # comes from model.synthesise() -- TODO confirm.
        print(f"RTF:\t\t{output['rtf']:.6f}")
        print(f"RTF Waveform:\t{rtf_w:.6f}")
        rtfs.append(output["rtf"])
        rtfs_w.append(rtf_w)

        # Save the generated waveform
        save_to_folder(i, output, folder=f"./my-output-{params.epoch}")

    print(f"Number of ODE steps: {n_timesteps}")
    print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
    print(
        f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Log lines carry a timestamp, the level and the emitting file:line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    # Restrict torch to one intra-op and one inter-op thread.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    main()
|