mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
updated
This commit is contained in:
parent
06c2993dfc
commit
b9e3d24c4d
@ -266,10 +266,10 @@ def main():
|
||||
raise ValueError(f"{params.vocoder} does not exist")
|
||||
|
||||
vocoder = load_vocoder(params.vocoder)
|
||||
vocoder = vocoder.to(device)
|
||||
vocoder.to(device)
|
||||
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
denoiser = denoiser.to(device)
|
||||
denoiser.to(device)
|
||||
|
||||
infer_dataset(
|
||||
dl=test_dl,
|
||||
|
@ -2,21 +2,18 @@
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from matcha.hifigan.config import v1, v2, v3
|
||||
from infer import load_vocoder, synthesise, to_waveform
|
||||
from matcha.hifigan.denoiser import Denoiser
|
||||
from matcha.hifigan.models import Generator as HiFiGAN
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.utils import AttributeDict, setup_logger
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -36,7 +33,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp-new-3",
|
||||
default="matcha/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -77,60 +74,16 @@ def get_parser():
|
||||
help="The filename of the wave to save the generated speech",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=22050,
|
||||
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def load_vocoder(checkpoint_path):
|
||||
checkpoint_path = str(checkpoint_path)
|
||||
if checkpoint_path.endswith("v1"):
|
||||
h = AttributeDict(v1)
|
||||
elif checkpoint_path.endswith("v2"):
|
||||
h = AttributeDict(v2)
|
||||
elif checkpoint_path.endswith("v3"):
|
||||
h = AttributeDict(v3)
|
||||
else:
|
||||
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
|
||||
|
||||
hifigan = HiFiGAN(h).to("cpu")
|
||||
hifigan.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||
)
|
||||
_ = hifigan.eval()
|
||||
hifigan.remove_weight_norm()
|
||||
return hifigan
|
||||
|
||||
|
||||
def to_waveform(mel, vocoder, denoiser):
|
||||
audio = vocoder(mel).clamp(-1, 1)
|
||||
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
|
||||
return audio.cpu().squeeze()
|
||||
|
||||
|
||||
def process_text(text: str, tokenizer):
|
||||
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
||||
x = torch.tensor(x, dtype=torch.long)
|
||||
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
|
||||
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
|
||||
|
||||
|
||||
def synthesise(
|
||||
model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
|
||||
):
|
||||
text_processed = process_text(text, tokenizer)
|
||||
start_t = dt.datetime.now()
|
||||
output = model.synthesise(
|
||||
text_processed["x"],
|
||||
text_processed["x_lengths"],
|
||||
n_timesteps=n_timesteps,
|
||||
temperature=temperature,
|
||||
spks=spks,
|
||||
length_scale=length_scale,
|
||||
)
|
||||
# merge everything to one dict
|
||||
output.update({"start_t": start_t, **text_processed})
|
||||
return output
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
@ -139,6 +92,12 @@ def main():
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info("Infer started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
@ -151,43 +110,57 @@ def main():
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
# Number of ODE Solver steps
|
||||
params.n_timesteps = 2
|
||||
|
||||
# Changes to the speaking rate
|
||||
params.length_scale = 1.0
|
||||
|
||||
# Sampling temperature
|
||||
params.temperature = 0.667
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if not Path(params.vocoder).is_file():
|
||||
raise ValueError(f"{params.vocoder} does not exist")
|
||||
|
||||
vocoder = load_vocoder(params.vocoder)
|
||||
vocoder.to(device)
|
||||
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
|
||||
# Number of ODE Solver steps
|
||||
n_timesteps = 2
|
||||
|
||||
# Changes to the speaking rate
|
||||
length_scale = 1.0
|
||||
|
||||
# Sampling temperature
|
||||
temperature = 0.667
|
||||
denoiser.to(device)
|
||||
|
||||
output = synthesise(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
n_timesteps=n_timesteps,
|
||||
n_timesteps=params.n_timesteps,
|
||||
text=params.input_text,
|
||||
length_scale=length_scale,
|
||||
temperature=temperature,
|
||||
length_scale=params.length_scale,
|
||||
temperature=params.temperature,
|
||||
device=device,
|
||||
)
|
||||
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||
|
||||
sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
|
||||
sf.write(
|
||||
file=params.output_wav,
|
||||
data=output["waveform"],
|
||||
samplerate=params.sampling_rate,
|
||||
subtype="PCM_16",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user