mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 03:52:18 +00:00
minor adjustment
This commit is contained in:
parent
fb34991566
commit
64d8a430d6
2
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
2
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
@ -56,7 +56,7 @@ function infer() {
|
|||||||
|
|
||||||
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
||||||
|
|
||||||
./matcha/synth.py \
|
./matcha/infer.py \
|
||||||
--epoch 1 \
|
--epoch 1 \
|
||||||
--exp-dir ./matcha/exp \
|
--exp-dir ./matcha/exp \
|
||||||
--tokens data/tokens.txt \
|
--tokens data/tokens.txt \
|
||||||
|
@ -65,6 +65,28 @@ def get_parser():
|
|||||||
help="""Path to vocabulary.""",
|
help="""Path to vocabulary.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The following arguments are used for inference on single text
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-text",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="The text to generate speech for",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-wav",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="The filename of the wave to save the generated speech",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=22050,
|
||||||
|
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -103,7 +125,7 @@ def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
|
|||||||
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
|
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
|
||||||
|
|
||||||
|
|
||||||
def synthesise(
|
def synthesize(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
n_timesteps: int,
|
n_timesteps: int,
|
||||||
@ -169,7 +191,7 @@ def infer_dataset(
|
|||||||
cut_ids = [cut.id for cut in batch["cut"]]
|
cut_ids = [cut.id for cut in batch["cut"]]
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
output = synthesise(
|
output = synthesize(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
n_timesteps=params.n_timesteps,
|
n_timesteps=params.n_timesteps,
|
||||||
@ -271,15 +293,35 @@ def main():
|
|||||||
denoiser = Denoiser(vocoder, mode="zeros")
|
denoiser = Denoiser(vocoder, mode="zeros")
|
||||||
denoiser.to(device)
|
denoiser.to(device)
|
||||||
|
|
||||||
infer_dataset(
|
if params.input_text is not None and params.output_wav is not None:
|
||||||
dl=test_dl,
|
logging.info("Synthesizing a single text")
|
||||||
params=params,
|
output = synthesize(
|
||||||
model=model,
|
model=model,
|
||||||
vocoder=vocoder,
|
tokenizer=tokenizer,
|
||||||
denoiser=denoiser,
|
n_timesteps=params.n_timesteps,
|
||||||
tokenizer=tokenizer,
|
text=params.input_text,
|
||||||
)
|
length_scale=params.length_scale,
|
||||||
|
temperature=params.temperature,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||||
|
|
||||||
|
sf.write(
|
||||||
|
file=params.output_wav,
|
||||||
|
data=output["waveform"],
|
||||||
|
samplerate=params.sampling_rate,
|
||||||
|
subtype="PCM_16",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Decoding the test set")
|
||||||
|
infer_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
denoiser=denoiser,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -1,166 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import soundfile as sf
|
|
||||||
import torch
|
|
||||||
from hifigan.denoiser import Denoiser
|
|
||||||
from infer import load_vocoder, synthesise, to_waveform
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
from train import get_model, get_params
|
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--epoch",
|
|
||||||
type=int,
|
|
||||||
default=4000,
|
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
|
||||||
Note: Epoch counts from 1.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--exp-dir",
|
|
||||||
type=Path,
|
|
||||||
default="matcha/exp",
|
|
||||||
help="""The experiment dir.
|
|
||||||
It specifies the directory where all training related
|
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--vocoder",
|
|
||||||
type=Path,
|
|
||||||
default="./generator_v1",
|
|
||||||
help="Path to the vocoder",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokens",
|
|
||||||
type=Path,
|
|
||||||
default="data/tokens.txt",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--cmvn",
|
|
||||||
type=str,
|
|
||||||
default="data/fbank/cmvn.json",
|
|
||||||
help="""Path to vocabulary.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-text",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The text to generate speech for",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-wav",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The filename of the wave to save the generated speech",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--sampling-rate",
|
|
||||||
type=int,
|
|
||||||
default=22050,
|
|
||||||
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def main():
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
params = get_params()
|
|
||||||
|
|
||||||
params.update(vars(args))
|
|
||||||
|
|
||||||
logging.info("Infer started")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
|
|
||||||
tokenizer = Tokenizer(params.tokens)
|
|
||||||
params.blank_id = tokenizer.pad_id
|
|
||||||
params.vocab_size = tokenizer.vocab_size
|
|
||||||
params.model_args.n_vocab = params.vocab_size
|
|
||||||
|
|
||||||
with open(params.cmvn) as f:
|
|
||||||
stats = json.load(f)
|
|
||||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
|
||||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
|
||||||
|
|
||||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
|
||||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
|
||||||
|
|
||||||
# Number of ODE Solver steps
|
|
||||||
params.n_timesteps = 2
|
|
||||||
|
|
||||||
# Changes to the speaking rate
|
|
||||||
params.length_scale = 1.0
|
|
||||||
|
|
||||||
# Sampling temperature
|
|
||||||
params.temperature = 0.667
|
|
||||||
logging.info(params)
|
|
||||||
|
|
||||||
logging.info("About to create model")
|
|
||||||
model = get_model(params)
|
|
||||||
|
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if not Path(params.vocoder).is_file():
|
|
||||||
raise ValueError(f"{params.vocoder} does not exist")
|
|
||||||
|
|
||||||
vocoder = load_vocoder(params.vocoder)
|
|
||||||
vocoder.to(device)
|
|
||||||
|
|
||||||
denoiser = Denoiser(vocoder, mode="zeros")
|
|
||||||
denoiser.to(device)
|
|
||||||
|
|
||||||
output = synthesise(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
n_timesteps=params.n_timesteps,
|
|
||||||
text=params.input_text,
|
|
||||||
length_scale=params.length_scale,
|
|
||||||
temperature=params.temperature,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
|
||||||
|
|
||||||
sf.write(
|
|
||||||
file=params.output_wav,
|
|
||||||
data=output["waveform"],
|
|
||||||
samplerate=params.sampling_rate,
|
|
||||||
subtype="PCM_16",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
main()
|
|
Loading…
x
Reference in New Issue
Block a user