Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-12 19:42:19 +00:00)

commit b9e3d24c4d ("updated")
parent 06c2993dfc
@@ -266,10 +266,10 @@ def main():
         raise ValueError(f"{params.vocoder} does not exist")

     vocoder = load_vocoder(params.vocoder)
-    vocoder = vocoder.to(device)
+    vocoder.to(device)

     denoiser = Denoiser(vocoder, mode="zeros")
-    denoiser = denoiser.to(device)
+    denoiser.to(device)

     infer_dataset(
         dl=test_dl,
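Note: the `vocoder = vocoder.to(device)` vs. `vocoder.to(device)` change above is behavior-preserving. For `torch.nn.Module` objects, `.to()` moves parameters and buffers in place and returns the module itself, so the re-assignment is optional. A minimal standalone sketch (the `Linear` layer is only an illustration, not part of the recipe):

import torch

layer = torch.nn.Linear(2, 2)   # stand-in for the vocoder/denoiser modules
moved = layer.to("cpu")         # nn.Module.to() converts in place ...
assert moved is layer           # ... and returns the same object
# Note: torch.Tensor.to() behaves differently and returns a new tensor.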
@@ -2,21 +2,18 @@
 # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

 import argparse
-import datetime as dt
 import json
 import logging
 from pathlib import Path

 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
+from infer import load_vocoder, synthesise, to_waveform
 from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params

 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict, setup_logger


 def get_parser():
@@ -36,7 +33,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-new-3",
+        default="matcha/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -77,60 +74,16 @@ def get_parser():
         help="The filename of the wave to save the generated speech",
     )

+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
+    )
+
     return parser


-def load_vocoder(checkpoint_path):
-    checkpoint_path = str(checkpoint_path)
-    if checkpoint_path.endswith("v1"):
-        h = AttributeDict(v1)
-    elif checkpoint_path.endswith("v2"):
-        h = AttributeDict(v2)
-    elif checkpoint_path.endswith("v3"):
-        h = AttributeDict(v3)
-    else:
-        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
-
-    hifigan = HiFiGAN(h).to("cpu")
-    hifigan.load_state_dict(
-        torch.load(checkpoint_path, map_location="cpu")["generator"]
-    )
-    _ = hifigan.eval()
-    hifigan.remove_weight_norm()
-    return hifigan
-
-
-def to_waveform(mel, vocoder, denoiser):
-    audio = vocoder(mel).clamp(-1, 1)
-    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
-    return audio.cpu().squeeze()
-
-
-def process_text(text: str, tokenizer):
-    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
-    x = torch.tensor(x, dtype=torch.long)
-    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
-
-
-def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
-    text_processed = process_text(text, tokenizer)
-    start_t = dt.datetime.now()
-    output = model.synthesise(
-        text_processed["x"],
-        text_processed["x_lengths"],
-        n_timesteps=n_timesteps,
-        temperature=temperature,
-        spks=spks,
-        length_scale=length_scale,
-    )
-    # merge everything to one dict
-    output.update({"start_t": start_t, **text_processed})
-    return output
-
-
 @torch.inference_mode()
 def main():
     parser = get_parser()
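Note: the new `--sampling-rate` flag replaces the hard-coded 22050 used later when the WAV file is written, and the helper functions removed here are now imported from `infer` instead of being duplicated. A self-contained sketch of how the added argument behaves (the throwaway parser below is only a demo, not the recipe's full `get_parser()`):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sampling-rate",
    type=int,
    default=22050,
    help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
)

print(parser.parse_args([]).sampling_rate)                             # 22050
print(parser.parse_args(["--sampling-rate", "16000"]).sampling_rate)   # 16000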
@@ -139,6 +92,12 @@ def main():

     params.update(vars(args))

+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
@@ -151,43 +110,57 @@ def main():

     params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
     params.model_args.data_statistics.mel_std = stats["fbank_std"]

+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
     logging.info(params)

     logging.info("About to create model")
     model = get_model(params)

     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()

     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")

     vocoder = load_vocoder(params.vocoder)
+    vocoder.to(device)

     denoiser = Denoiser(vocoder, mode="zeros")
-    # Number of ODE Solver steps
-    n_timesteps = 2
-
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
+    denoiser.to(device)

     output = synthesise(
         model=model,
         tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
+        n_timesteps=params.n_timesteps,
         text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
+        length_scale=params.length_scale,
+        temperature=params.temperature,
+        device=device,
     )
     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
+    sf.write(
+        file=params.output_wav,
+        data=output["waveform"],
+        samplerate=params.sampling_rate,
+        subtype="PCM_16",
+    )


 if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     main()
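Note: with the keyword-style `sf.write` call, the sample rate written into the WAV header follows `params.sampling_rate` instead of a hard-coded 22050. A minimal standalone example of that call (the 440 Hz tone and the `test.wav` filename are made up for illustration):

import numpy as np
import soundfile as sf

sampling_rate = 22050
t = np.linspace(0, 1.0, sampling_rate, endpoint=False)
waveform = 0.1 * np.sin(2 * np.pi * 440.0 * t)  # one second of a 440 Hz tone

sf.write(
    file="test.wav",
    data=waveform,
    samplerate=sampling_rate,
    subtype="PCM_16",
)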