mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
minor adjustment
This commit is contained in:
parent
fb34991566
commit
64d8a430d6
2
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
2
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
@ -56,7 +56,7 @@ function infer() {
|
||||
|
||||
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
||||
|
||||
./matcha/synth.py \
|
||||
./matcha/infer.py \
|
||||
--epoch 1 \
|
||||
--exp-dir ./matcha/exp \
|
||||
--tokens data/tokens.txt \
|
||||
|
@ -65,6 +65,28 @@ def get_parser():
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
# The following arguments are used for inference on single text
|
||||
parser.add_argument(
|
||||
"--input-text",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The text to generate speech for",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-wav",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The filename of the wave to save the generated speech",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=22050,
|
||||
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -103,7 +125,7 @@ def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
|
||||
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
|
||||
|
||||
|
||||
def synthesise(
|
||||
def synthesize(
|
||||
model: nn.Module,
|
||||
tokenizer: Tokenizer,
|
||||
n_timesteps: int,
|
||||
@ -169,7 +191,7 @@ def infer_dataset(
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
for i in range(batch_size):
|
||||
output = synthesise(
|
||||
output = synthesize(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
n_timesteps=params.n_timesteps,
|
||||
@ -271,15 +293,35 @@ def main():
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
denoiser.to(device)
|
||||
|
||||
infer_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
vocoder=vocoder,
|
||||
denoiser=denoiser,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
if params.input_text is not None and params.output_wav is not None:
|
||||
logging.info("Synthesizing a single text")
|
||||
output = synthesize(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
n_timesteps=params.n_timesteps,
|
||||
text=params.input_text,
|
||||
length_scale=params.length_scale,
|
||||
temperature=params.temperature,
|
||||
device=device,
|
||||
)
|
||||
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||
|
||||
sf.write(
|
||||
file=params.output_wav,
|
||||
data=output["waveform"],
|
||||
samplerate=params.sampling_rate,
|
||||
subtype="PCM_16",
|
||||
)
|
||||
else:
|
||||
logging.info("Decoding the test set")
|
||||
infer_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
vocoder=vocoder,
|
||||
denoiser=denoiser,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -1,166 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from hifigan.denoiser import Denoiser
|
||||
from infer import load_vocoder, synthesise, to_waveform
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=4000,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
type=Path,
|
||||
default="./generator_v1",
|
||||
help="Path to the vocoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=Path,
|
||||
default="data/tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input-text",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The text to generate speech for",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-wav",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The filename of the wave to save the generated speech",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=22050,
|
||||
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info("Infer started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
# Number of ODE Solver steps
|
||||
params.n_timesteps = 2
|
||||
|
||||
# Changes to the speaking rate
|
||||
params.length_scale = 1.0
|
||||
|
||||
# Sampling temperature
|
||||
params.temperature = 0.667
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if not Path(params.vocoder).is_file():
|
||||
raise ValueError(f"{params.vocoder} does not exist")
|
||||
|
||||
vocoder = load_vocoder(params.vocoder)
|
||||
vocoder.to(device)
|
||||
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
denoiser.to(device)
|
||||
|
||||
output = synthesise(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
n_timesteps=params.n_timesteps,
|
||||
text=params.input_text,
|
||||
length_scale=params.length_scale,
|
||||
temperature=params.temperature,
|
||||
device=device,
|
||||
)
|
||||
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||
|
||||
sf.write(
|
||||
file=params.output_wav,
|
||||
data=output["waveform"],
|
||||
samplerate=params.sampling_rate,
|
||||
subtype="PCM_16",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user