mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)
minor updates
This commit is contained in:
  parent 53ec156198
  commit 6e5c3e4032
2  .github/scripts/ljspeech/TTS/run-matcha.sh (vendored)
@@ -56,7 +56,7 @@ function infer() {
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
 
-  ./matcha/infer.py \
+  ./matcha/synth.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
egs/ljspeech/TTS/matcha/infer.py
@@ -9,14 +9,16 @@ from pathlib import Path
 
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
-from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
+import torch.nn as nn
+from hifigan.config import v1, v2, v3
+from hifigan.denoiser import Denoiser
+from hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict
+from icefall.utils import AttributeDict, setup_logger
 
 
 def get_parser():
@@ -63,24 +65,10 @@ def get_parser():
         help="""Path to the cmvn statistics file.""",
     )
 
-    parser.add_argument(
-        "--input-text",
-        type=str,
-        required=True,
-        help="The text to generate speech for",
-    )
-
-    parser.add_argument(
-        "--output-wav",
-        type=str,
-        required=True,
-        help="The filename of the wave to save the generated speech",
-    )
-
     return parser
 
 
-def load_vocoder(checkpoint_path):
+def load_vocoder(checkpoint_path: Path) -> nn.Module:
     checkpoint_path = str(checkpoint_path)
     if checkpoint_path.endswith("v1"):
         h = AttributeDict(v1)
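Note that load_vocoder() picks the HiFiGAN config purely from the checkpoint filename suffix, so the file fetched in run-matcha.sh must keep its generator_v1 name. A minimal usage sketch follows (a hedged example, not part of the commit; the names are the ones defined in this file):

# Load the vocoder downloaded by run-matcha.sh; the "v1" suffix selects config v1.
vocoder = load_vocoder(Path("./generator_v1"))
# The Denoiser wraps the vocoder to suppress its bias signal, as done in main() below.
denoiser = Denoiser(vocoder, mode="zeros")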
@@ -100,13 +88,15 @@ def load_vocoder(checkpoint_path):
     return hifigan
 
 
-def to_waveform(mel, vocoder, denoiser):
+def to_waveform(
+    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
+) -> torch.Tensor:
     audio = vocoder(mel).clamp(-1, 1)
     audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
     return audio.cpu().squeeze()
 
 
-def process_text(text: str, tokenizer):
+def process_text(text: str, tokenizer: Tokenizer) -> dict:
     x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
     x = torch.tensor(x, dtype=torch.long)
     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
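For orientation, a hedged sketch of the dict that process_text() returns for one utterance; the shapes follow from texts_to_token_ids() producing one id sequence per input string (the example text is illustrative, not part of the commit):

# tokenizer as constructed in main() below
batch = process_text("Hello world", tokenizer)
# batch["x_orig"]    -> "Hello world" (the raw input text)
# batch["x"]         -> LongTensor of shape (1, num_tokens), SOS/EOS included
# batch["x_lengths"] -> LongTensor([num_tokens])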
@@ -114,8 +104,14 @@ def process_text(text: str, tokenizer):
 
 
 def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    n_timesteps: int,
+    text: str,
+    length_scale: float,
+    temperature: float,
+    spks=None,
+) -> dict:
     text_processed = process_text(text, tokenizer)
     start_t = dt.datetime.now()
     output = model.synthesise(
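A hedged sketch of calling the retyped synthesise() directly, with the same knob values that main() sets below (the text is illustrative; model, tokenizer, vocoder, and denoiser are the objects built in main()):

output = synthesise(
    model=model,          # acoustic model from get_model() + load_checkpoint()
    tokenizer=tokenizer,
    n_timesteps=2,        # number of ODE solver steps
    text="Hello world",
    length_scale=1.0,     # speaking rate
    temperature=0.667,    # sampling temperature
)
# output carries "mel" plus the fields merged in from process_text()
audio = to_waveform(output["mel"], vocoder, denoiser)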
@@ -131,14 +127,102 @@ def synthesise(
     return output
 
 
+def infer_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    vocoder: nn.Module,
+    denoiser: nn.Module,
+    tokenizer: Tokenizer,
+) -> None:
+    """Decode dataset.
+    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      tokenizer:
+        Used to convert text to phonemes.
+    """
+
+    device = next(model.parameters()).device
+    num_cuts = 0
+    log_interval = 5
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    for batch_idx, batch in enumerate(dl):
+        batch_size = len(batch["tokens"])
+
+        texts = batch["supervisions"]["text"]
+
+        audio = batch["audio"]
+        audio_lens = batch["audio_lens"].tolist()
+        cut_ids = [cut.id for cut in batch["cut"]]
+
+        for i in range(batch_size):
+            output = synthesise(
+                model=model,
+                tokenizer=tokenizer,
+                n_timesteps=params.n_timesteps,
+                text=texts[i],
+                length_scale=params.length_scale,
+                temperature=params.temperature,
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
+                data=output["waveform"],
+                samplerate=params.sampling_rate,
+                subtype="PCM_16",
+            )
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
+                data=audio[i].numpy(),
+                samplerate=params.sampling_rate,
+                subtype="PCM_16",
+            )
+
+        num_cuts += batch_size
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+
 @torch.inference_mode()
 def main():
     parser = get_parser()
+    LJSpeechTtsDataModule.add_arguments(parser)
     args = parser.parse_args()
-    params = get_params()
+    args.exp_dir = Path(args.exp_dir)
 
+    params = get_params()
     params.update(vars(args))
 
+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"Device: {device}")
+
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
@@ -151,49 +235,49 @@ def main():
 
     params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
     params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
     logging.info(params)
 
     logging.info("About to create model")
     model = get_model(params)
 
+    if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
+        raise ValueError(f"{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
+
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()
 
+    # we need cut ids to organize tts results.
+    args.return_cuts = True
+    ljspeech = LJSpeechTtsDataModule(args)
+
+    test_cuts = ljspeech.test_cuts()
+    test_dl = ljspeech.test_dataloaders(test_cuts)
+
     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
+    vocoder = vocoder.to(device)
+
     denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser = denoiser.to(device)
 
-    # Number of ODE Solver steps
-    n_timesteps = 2
-
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
-
-    output = synthesise(
+    infer_dataset(
+        dl=test_dl,
+        params=params,
         model=model,
+        vocoder=vocoder,
+        denoiser=denoiser,
         tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
-        text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
     )
-    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
-
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     main()
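With --input-text/--output-wav removed, infer.py now decodes the whole LJSpeech test set and writes paired files per cut. A hedged sketch of the resulting layout, assuming --epoch 1 and --exp-dir ./matcha/exp (the cut id LJ001-0001 is illustrative; params.sampling_rate is the 22050 set in train.py, renamed from sample_rate in this same commit):

from pathlib import Path

# params.save_wav_dir as composed in main(): <exp-dir>/infer/epoch-<N>/wav
save_wav_dir = Path("matcha/exp") / "infer" / "epoch-1" / "wav"
pred_wav = save_wav_dir / "LJ001-0001_pred.wav"  # generated speech
gt_wav = save_wav_dir / "LJ001-0001_gt.wav"      # ground-truth recording
# setup_logger() writes under matcha/exp/infer/epoch-1/log-infer-epoch-1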
193  egs/ljspeech/TTS/matcha/synth.py (new executable file)
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import datetime as dt
+import json
+import logging
+from pathlib import Path
+
+import soundfile as sf
+import torch
+from matcha.hifigan.config import v1, v2, v3
+from matcha.hifigan.denoiser import Denoiser
+from matcha.hifigan.models import Generator as HiFiGAN
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=4000,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp-new-3",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--vocoder",
+        type=Path,
+        default="./generator_v1",
+        help="Path to the vocoder",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to the cmvn statistics file.""",
+    )
+
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=True,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=True,
+        help="The filename of the wave to save the generated speech",
+    )
+
+    return parser
+
+
+def load_vocoder(checkpoint_path):
+    checkpoint_path = str(checkpoint_path)
+    if checkpoint_path.endswith("v1"):
+        h = AttributeDict(v1)
+    elif checkpoint_path.endswith("v2"):
+        h = AttributeDict(v2)
+    elif checkpoint_path.endswith("v3"):
+        h = AttributeDict(v3)
+    else:
+        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
+
+    hifigan = HiFiGAN(h).to("cpu")
+    hifigan.load_state_dict(
+        torch.load(checkpoint_path, map_location="cpu")["generator"]
+    )
+    _ = hifigan.eval()
+    hifigan.remove_weight_norm()
+    return hifigan
+
+
+def to_waveform(mel, vocoder, denoiser):
+    audio = vocoder(mel).clamp(-1, 1)
+    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
+    return audio.cpu().squeeze()
+
+
+def process_text(text: str, tokenizer):
+    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.long)
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
+
+
+def synthesise(
+    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
+):
+    text_processed = process_text(text, tokenizer)
+    start_t = dt.datetime.now()
+    output = model.synthesise(
+        text_processed["x"],
+        text_processed["x_lengths"],
+        n_timesteps=n_timesteps,
+        temperature=temperature,
+        spks=spks,
+        length_scale=length_scale,
+    )
+    # merge everything to one dict
+    output.update({"start_t": start_t, **text_processed})
+    return output
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    params = get_params()
+
+    params.update(vars(args))
+
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.eval()
+
+    if not Path(params.vocoder).is_file():
+        raise ValueError(f"{params.vocoder} does not exist")
+
+    vocoder = load_vocoder(params.vocoder)
+    denoiser = Denoiser(vocoder, mode="zeros")
+
+    # Number of ODE Solver steps
+    n_timesteps = 2
+
+    # Changes to the speaking rate
+    length_scale = 1.0
+
+    # Sampling temperature
+    temperature = 0.667
+
+    output = synthesise(
+        model=model,
+        tokenizer=tokenizer,
+        n_timesteps=n_timesteps,
+        text=params.input_text,
+        length_scale=length_scale,
+        temperature=temperature,
+    )
+    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    main()
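synth.py preserves the single-utterance command-line interface that infer.py dropped. An invocation consistent with the flags defined in its get_parser() would be, for example, ./matcha/synth.py --epoch 1 --exp-dir ./matcha/exp --tokens data/tokens.txt --input-text "Hello world" --output-wav ./hello.wav (input text and output path are illustrative), extending the call updated in run-matcha.sh above.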
egs/ljspeech/TTS/matcha/train.py
@@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict:
             "n_spks": 1,
             "n_fft": 1024,
             "n_feats": 80,
-            "sample_rate": 22050,
+            "sampling_rate": 22050,
             "hop_length": 256,
             "win_length": 1024,
             "f_min": 0,
@@ -445,11 +445,6 @@ def train_one_epoch(
 
     saved_bad_model = False
 
-    # used to track the stats over iterations in one epoch
-    tot_loss = MetricsTracker()
-
-    saved_bad_model = False
-
     def save_bad_model(suffix: str = ""):
         save_checkpoint(
             filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
egs/ljspeech/TTS/vits/infer.py
@@ -234,7 +234,7 @@ def main():
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
     logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
 
-    # we need cut ids to display recognition results.
+    # we need cut ids to organize tts results.
     args.return_cuts = True
     ljspeech = LJSpeechTtsDataModule(args)
 
@@ -18,7 +18,6 @@
 
 from tokenizer import Tokenizer
 from train import get_model, get_params
 from vits import VITS
 
 
 def test_model_type(model_type):