minor updates

zr_jin 2024-11-04 16:37:30 +08:00
parent 53ec156198
commit 6e5c3e4032
6 changed files with 329 additions and 58 deletions

File 1 of 6:

@@ -56,7 +56,7 @@ function infer() {

   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

-  ./matcha/infer.py \
+  ./matcha/synth.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
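
Note: this renames the single-utterance entry point from matcha/infer.py to matcha/synth.py (added below as a new file, still driven by --input-text and --output-wav), while matcha/infer.py is repurposed to decode the LJSpeech test set.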

File 2 of 6:
@@ -9,14 +9,16 @@ from pathlib import Path

 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
-from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
+import torch.nn as nn
+from hifigan.config import v1, v2, v3
+from hifigan.denoiser import Denoiser
+from hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule

 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict
+from icefall.utils import AttributeDict, setup_logger


 def get_parser():

@@ -63,24 +65,10 @@ def get_parser():
         help="""Path to vocabulary.""",
     )

-    parser.add_argument(
-        "--input-text",
-        type=str,
-        required=True,
-        help="The text to generate speech for",
-    )
-
-    parser.add_argument(
-        "--output-wav",
-        type=str,
-        required=True,
-        help="The filename of the wave to save the generated speech",
-    )
-
     return parser


-def load_vocoder(checkpoint_path):
+def load_vocoder(checkpoint_path: Path) -> nn.Module:
     checkpoint_path = str(checkpoint_path)
     if checkpoint_path.endswith("v1"):
         h = AttributeDict(v1)

@@ -100,13 +88,15 @@ def load_vocoder(checkpoint_path):
     return hifigan


-def to_waveform(mel, vocoder, denoiser):
+def to_waveform(
+    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
+) -> torch.Tensor:
     audio = vocoder(mel).clamp(-1, 1)
     audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
     return audio.cpu().squeeze()


-def process_text(text: str, tokenizer):
+def process_text(text: str, tokenizer: Tokenizer) -> dict:
     x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
     x = torch.tensor(x, dtype=torch.long)
     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
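
For orientation, a minimal sketch of what the newly annotated process_text() returns; the input string is arbitrary and the tokens path is the script's default:

    from tokenizer import Tokenizer

    tokenizer = Tokenizer("data/tokens.txt")        # default of --tokens
    batch = process_text("Hello world", tokenizer)
    # batch["x"]:         (1, T) LongTensor of token ids, SOS/EOS included
    # batch["x_lengths"]: LongTensor([T]) on CPU
    # batch["x_orig"]:    the original input string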
@@ -114,8 +104,14 @@ def process_text(text: str, tokenizer):


 def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    n_timesteps: int,
+    text: str,
+    length_scale: float,
+    temperature: float,
+    spks=None,
+) -> dict:
     text_processed = process_text(text, tokenizer)
     start_t = dt.datetime.now()
     output = model.synthesise(
@@ -131,14 +127,102 @@ def synthesise(
     return output


+def infer_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    vocoder: nn.Module,
+    denoiser: nn.Module,
+    tokenizer: Tokenizer,
+) -> None:
+    """Decode dataset.
+    The ground-truth and generated audio pairs will be saved to
+    `params.save_wav_dir`.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      vocoder:
+        The vocoder that converts mel spectrograms to waveforms.
+      denoiser:
+        Used to reduce vocoder artifacts in the generated audio.
+      tokenizer:
+        Used to convert text to phonemes.
+    """
+    device = next(model.parameters()).device
+    num_cuts = 0
+    log_interval = 5
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    for batch_idx, batch in enumerate(dl):
+        batch_size = len(batch["tokens"])
+
+        texts = batch["supervisions"]["text"]
+        audio = batch["audio"]
+        audio_lens = batch["audio_lens"].tolist()
+        cut_ids = [cut.id for cut in batch["cut"]]
+
+        for i in range(batch_size):
+            output = synthesise(
+                model=model,
+                tokenizer=tokenizer,
+                n_timesteps=params.n_timesteps,
+                text=texts[i],
+                length_scale=params.length_scale,
+                temperature=params.temperature,
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
+                data=output["waveform"],
+                samplerate=params.sampling_rate,
+                subtype="PCM_16",
+            )
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
+                data=audio[i].numpy(),
+                samplerate=params.sampling_rate,
+                subtype="PCM_16",
+            )
+
+        num_cuts += batch_size
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+
+
 @torch.inference_mode()
 def main():
     parser = get_parser()
+    LJSpeechTtsDataModule.add_arguments(parser)
     args = parser.parse_args()

+    args.exp_dir = Path(args.exp_dir)
+
     params = get_params()
     params.update(vars(args))

+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"Device: {device}")
+
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
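
As a reading aid, a hand-built example of the batch layout infer_dataset() consumes. Field names are taken from the function body; in the real pipeline the lhotse-based LJSpeechTtsDataModule produces these batches, and the stand-in cut class below is purely illustrative:

    import torch

    class FakeCut:  # real batches carry lhotse Cut objects; only .id is read here
        def __init__(self, cut_id):
            self.id = cut_id

    batch = {
        "tokens": [[1, 2, 3]],                      # only len() is used (batch size)
        "supervisions": {"text": ["Hello world"]},  # one transcript per cut
        "audio": torch.zeros(1, 22050),             # ground-truth waveforms
        "audio_lens": torch.tensor([22050]),        # valid samples per cut
        "cut": [FakeCut("LJ001-0001")],             # cut.id names the saved wavs
    }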
@@ -151,49 +235,49 @@ def main():
     params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
     params.model_args.data_statistics.mel_std = stats["fbank_std"]

+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
+
     logging.info(params)

     logging.info("About to create model")
     model = get_model(params)

+    if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
+        raise ValueError(f"{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
+
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()

+    # we need cut ids to organize tts results.
+    args.return_cuts = True
+    ljspeech = LJSpeechTtsDataModule(args)
+
+    test_cuts = ljspeech.test_cuts()
+    test_dl = ljspeech.test_dataloaders(test_cuts)
+
     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")

     vocoder = load_vocoder(params.vocoder)
+    vocoder = vocoder.to(device)
     denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser = denoiser.to(device)

-    # Number of ODE Solver steps
-    n_timesteps = 2
-
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
-
-    output = synthesise(
-        model=model,
-        tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
-        text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
-    )
-    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
-
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
+    infer_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        vocoder=vocoder,
+        denoiser=denoiser,
+        tokenizer=tokenizer,
+    )


 if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     main()
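
Taken together, the new main() derives its output layout from --exp-dir and --epoch; a small sketch (the exp-dir value is an assumption, the path scheme comes from the code above):

    from pathlib import Path

    exp_dir = Path("matcha/exp")           # assumed --exp-dir
    suffix = "epoch-4000"                  # f"epoch-{params.epoch}", default epoch 4000
    res_dir = exp_dir / "infer" / suffix   # log-infer-epoch-4000 is written here
    save_wav_dir = res_dir / "wav"         # <cut_id>_pred.wav / <cut_id>_gt.wav pairs
    print(save_wav_dir)                    # matcha/exp/infer/epoch-4000/wav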

File 3 of 6: egs/ljspeech/TTS/matcha/synth.py (new executable file, 193 lines)

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import datetime as dt
import json
import logging
from pathlib import Path

import soundfile as sf
import torch
from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=4000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp-new-3",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--vocoder",
        type=Path,
        default="./generator_v1",
        help="Path to the vocoder",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--input-text",
        type=str,
        required=True,
        help="The text to generate speech for",
    )

    parser.add_argument(
        "--output-wav",
        type=str,
        required=True,
        help="The filename of the wave to save the generated speech",
    )

    return parser


def load_vocoder(checkpoint_path):
    checkpoint_path = str(checkpoint_path)
    if checkpoint_path.endswith("v1"):
        h = AttributeDict(v1)
    elif checkpoint_path.endswith("v2"):
        h = AttributeDict(v2)
    elif checkpoint_path.endswith("v3"):
        h = AttributeDict(v3)
    else:
        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")

    hifigan = HiFiGAN(h).to("cpu")
    hifigan.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")["generator"]
    )
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def to_waveform(mel, vocoder, denoiser):
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio.cpu().squeeze()


def process_text(text: str, tokenizer):
    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    x = torch.tensor(x, dtype=torch.long)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
    return {"x_orig": text, "x": x, "x_lengths": x_lengths}


def synthesise(
    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
):
    text_processed = process_text(text, tokenizer)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed["x"],
        text_processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    # merge everything to one dict
    output.update({"start_t": start_t, **text_processed})
    return output


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()

    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.eval()

    if not Path(params.vocoder).is_file():
        raise ValueError(f"{params.vocoder} does not exist")

    vocoder = load_vocoder(params.vocoder)
    denoiser = Denoiser(vocoder, mode="zeros")

    # Number of ODE Solver steps
    n_timesteps = 2

    # Changes to the speaking rate
    length_scale = 1.0

    # Sampling temperature
    temperature = 0.667

    output = synthesise(
        model=model,
        tokenizer=tokenizer,
        n_timesteps=n_timesteps,
        text=params.input_text,
        length_scale=length_scale,
        temperature=temperature,
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
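
One detail worth flagging in the new file: load_vocoder() picks the HiFiGAN config purely from the checkpoint's file name, so the downloaded generator must keep its v1/v2/v3 suffix. A sketch, using the script's default paths:

    vocoder = load_vocoder("./generator_v1")   # name ends with "v1" -> AttributeDict(v1)
    vocoder = load_vocoder("./generator_v3")   # name ends with "v3" -> AttributeDict(v3)
    load_vocoder("./checkpoint.pt")            # ValueError: supports only v1, v2, and v3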

File 4 of 6:
@@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict:
             "n_spks": 1,
             "n_fft": 1024,
             "n_feats": 80,
-            "sample_rate": 22050,
+            "sampling_rate": 22050,
             "hop_length": 256,
             "win_length": 1024,
             "f_min": 0,
@@ -445,11 +445,6 @@ def train_one_epoch(
     saved_bad_model = False

-    # used to track the stats over iterations in one epoch
-    tot_loss = MetricsTracker()
-
-    saved_bad_model = False
-
     def save_bad_model(suffix: str = ""):
         save_checkpoint(
             filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",

File 5 of 6:
@@ -234,7 +234,7 @@ def main():
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
     logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

-    # we need cut ids to display recognition results.
+    # we need cut ids to organize tts results.
     args.return_cuts = True
     ljspeech = LJSpeechTtsDataModule(args)

File 6 of 6:
@@ -18,7 +18,6 @@
 from tokenizer import Tokenizer
 from train import get_model, get_params
-from vits import VITS


 def test_model_type(model_type):