mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
First working version
This commit is contained in:
parent
e4d40baaf5
commit
e4f08c74f7
4
egs/baker_zh/TTS/.gitignore
vendored
4
egs/baker_zh/TTS/.gitignore
vendored
@ -1 +1,5 @@
|
|||||||
path.sh
|
path.sh
|
||||||
|
*.wav
|
||||||
|
generator_v1
|
||||||
|
generator_v2
|
||||||
|
generator_v3
|
||||||
|
@ -30,6 +30,8 @@ punctuations_re = [
|
|||||||
("’", "'"),
|
("’", "'"),
|
||||||
(":", ":"),
|
(":", ":"),
|
||||||
("、", ","),
|
("、", ","),
|
||||||
|
("B", "逼"),
|
||||||
|
("P", "批"),
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -108,7 +110,7 @@ def main():
|
|||||||
text_list = split_text(text)
|
text_list = split_text(text)
|
||||||
tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
|
tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
|
||||||
|
|
||||||
c.supervisions[0].tokens = tokens
|
c.tokens = tokens
|
||||||
|
|
||||||
cuts.to_file(args.out_file)
|
cuts.to_file(args.out_file)
|
||||||
|
|
||||||
|
@ -43,6 +43,22 @@ def generate_token_list() -> List[str]:
|
|||||||
t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
|
t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
|
||||||
token_set.add(t)
|
token_set.add(t)
|
||||||
|
|
||||||
|
no_digit = set()
|
||||||
|
for t in token_set:
|
||||||
|
if t[-1] not in "1234":
|
||||||
|
no_digit.add(t)
|
||||||
|
else:
|
||||||
|
no_digit.add(t[:-1])
|
||||||
|
|
||||||
|
no_digit.add("dei")
|
||||||
|
no_digit.add("tou")
|
||||||
|
no_digit.add("dia")
|
||||||
|
|
||||||
|
for t in no_digit:
|
||||||
|
token_set.add(t)
|
||||||
|
for i in range(1, 5):
|
||||||
|
token_set.add(f"{t}{i}")
|
||||||
|
|
||||||
ans = list(token_set)
|
ans = list(token_set)
|
||||||
ans.sort()
|
ans.sort()
|
||||||
|
|
||||||
|
0
egs/baker_zh/TTS/matcha/__init__.py
Normal file
0
egs/baker_zh/TTS/matcha/__init__.py
Normal file
332
egs/baker_zh/TTS/matcha/infer.py
Executable file
332
egs/baker_zh/TTS/matcha/infer.py
Executable file
@ -0,0 +1,332 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hifigan.config import v1, v2, v3
|
||||||
|
from hifigan.denoiser import Denoiser
|
||||||
|
from hifigan.models import Generator as HiFiGAN
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from train import get_model, get_params
|
||||||
|
from tts_datamodule import BakerZhTtsDataModule
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
from icefall.utils import AttributeDict, setup_logger
|
||||||
|
from local.convert_text_to_tokens import split_text
|
||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=4000,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=Path,
|
||||||
|
default="matcha/exp",
|
||||||
|
help="""The experiment dir.
|
||||||
|
It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocoder",
|
||||||
|
type=Path,
|
||||||
|
default="./generator_v1",
|
||||||
|
help="Path to the vocoder",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=Path,
|
||||||
|
default="data/tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cmvn",
|
||||||
|
type=str,
|
||||||
|
default="data/fbank/cmvn.json",
|
||||||
|
help="""Path to vocabulary.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
# The following arguments are used for inference on single text
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-text",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="The text to generate speech for",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-wav",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="The filename of the wave to save the generated speech",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=22050,
|
||||||
|
help="The sampling rate of the generated speech (default: 22050 for baker_zh)",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocoder(checkpoint_path: Path) -> nn.Module:
|
||||||
|
checkpoint_path = str(checkpoint_path)
|
||||||
|
if checkpoint_path.endswith("v1"):
|
||||||
|
h = AttributeDict(v1)
|
||||||
|
elif checkpoint_path.endswith("v2"):
|
||||||
|
h = AttributeDict(v2)
|
||||||
|
elif checkpoint_path.endswith("v3"):
|
||||||
|
h = AttributeDict(v3)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
|
||||||
|
|
||||||
|
hifigan = HiFiGAN(h).to("cpu")
|
||||||
|
hifigan.load_state_dict(
|
||||||
|
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||||
|
)
|
||||||
|
_ = hifigan.eval()
|
||||||
|
hifigan.remove_weight_norm()
|
||||||
|
return hifigan
|
||||||
|
|
||||||
|
|
||||||
|
def to_waveform(
|
||||||
|
mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
|
||||||
|
) -> torch.Tensor:
|
||||||
|
audio = vocoder(mel).clamp(-1, 1)
|
||||||
|
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
|
||||||
|
return audio.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
|
||||||
|
text = split_text(text)
|
||||||
|
tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True)
|
||||||
|
|
||||||
|
x = tokenizer.texts_to_token_ids([tokens])
|
||||||
|
x = torch.tensor(x, dtype=torch.long, device=device)
|
||||||
|
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
|
||||||
|
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize(
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
n_timesteps: int,
|
||||||
|
text: str,
|
||||||
|
length_scale: float,
|
||||||
|
temperature: float,
|
||||||
|
device: str = "cpu",
|
||||||
|
spks=None,
|
||||||
|
) -> dict:
|
||||||
|
text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
output = model.synthesise(
|
||||||
|
text_processed["x"],
|
||||||
|
text_processed["x_lengths"],
|
||||||
|
n_timesteps=n_timesteps,
|
||||||
|
temperature=temperature,
|
||||||
|
spks=spks,
|
||||||
|
length_scale=length_scale,
|
||||||
|
)
|
||||||
|
# merge everything to one dict
|
||||||
|
output.update({"start_t": start_t, **text_processed})
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def infer_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
vocoder: nn.Module,
|
||||||
|
denoiser: nn.Module,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
) -> None:
|
||||||
|
"""Decode dataset.
|
||||||
|
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
tokenizer:
|
||||||
|
Used to convert text to phonemes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
num_cuts = 0
|
||||||
|
log_interval = 5
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
batch_size = len(batch["tokens"])
|
||||||
|
|
||||||
|
texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
|
||||||
|
|
||||||
|
audio = batch["audio"]
|
||||||
|
audio_lens = batch["audio_lens"].tolist()
|
||||||
|
cut_ids = [cut.id for cut in batch["cut"]]
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
output = synthesize(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
n_timesteps=params.n_timesteps,
|
||||||
|
text=texts[i],
|
||||||
|
length_scale=params.length_scale,
|
||||||
|
temperature=params.temperature,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||||
|
|
||||||
|
sf.write(
|
||||||
|
file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
|
||||||
|
data=output["waveform"],
|
||||||
|
samplerate=params.data_args.sampling_rate,
|
||||||
|
subtype="PCM_16",
|
||||||
|
)
|
||||||
|
sf.write(
|
||||||
|
file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
|
||||||
|
data=audio[i].numpy(),
|
||||||
|
samplerate=params.data_args.sampling_rate,
|
||||||
|
subtype="PCM_16",
|
||||||
|
)
|
||||||
|
|
||||||
|
num_cuts += batch_size
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
BakerZhTtsDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
params.suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
params.res_dir = params.exp_dir / "infer" / params.suffix
|
||||||
|
params.save_wav_dir = params.res_dir / "wav"
|
||||||
|
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
|
||||||
|
logging.info("Infer started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(params.tokens)
|
||||||
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
params.model_args.n_vocab = params.vocab_size
|
||||||
|
|
||||||
|
with open(params.cmvn) as f:
|
||||||
|
stats = json.load(f)
|
||||||
|
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||||
|
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||||
|
|
||||||
|
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||||
|
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||||
|
|
||||||
|
# Number of ODE Solver steps
|
||||||
|
params.n_timesteps = 2
|
||||||
|
|
||||||
|
# Changes to the speaking rate
|
||||||
|
params.length_scale = 1.0
|
||||||
|
|
||||||
|
# Sampling temperature
|
||||||
|
params.temperature = 0.667
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# we need cut ids to organize tts results.
|
||||||
|
args.return_cuts = True
|
||||||
|
baker_zh = BakerZhTtsDataModule(args)
|
||||||
|
|
||||||
|
test_cuts = baker_zh.test_cuts()
|
||||||
|
test_dl = baker_zh.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
|
if not Path(params.vocoder).is_file():
|
||||||
|
raise ValueError(f"{params.vocoder} does not exist")
|
||||||
|
|
||||||
|
vocoder = load_vocoder(params.vocoder)
|
||||||
|
vocoder.to(device)
|
||||||
|
|
||||||
|
denoiser = Denoiser(vocoder, mode="zeros")
|
||||||
|
denoiser.to(device)
|
||||||
|
|
||||||
|
if params.input_text is not None and params.output_wav is not None:
|
||||||
|
logging.info("Synthesizing a single text")
|
||||||
|
output = synthesize(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
n_timesteps=params.n_timesteps,
|
||||||
|
text=params.input_text,
|
||||||
|
length_scale=params.length_scale,
|
||||||
|
temperature=params.temperature,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||||
|
|
||||||
|
sf.write(
|
||||||
|
file=params.output_wav,
|
||||||
|
data=output["waveform"],
|
||||||
|
samplerate=params.sampling_rate,
|
||||||
|
subtype="PCM_16",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Decoding the test set")
|
||||||
|
infer_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
denoiser=denoiser,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -80,7 +80,7 @@ class Tokenizer(object):
|
|||||||
token_ids = []
|
token_ids = []
|
||||||
for t in tokens_list:
|
for t in tokens_list:
|
||||||
if t not in self.token2id:
|
if t not in self.token2id:
|
||||||
logging.warning(f"Skip OOV {t}")
|
logging.warning(f"Skip OOV {t} {sentence}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
|
if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
|
||||||
|
@ -315,7 +315,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param
|
|||||||
features_lens = batch["features_lens"].to(device)
|
features_lens = batch["features_lens"].to(device)
|
||||||
tokens = batch["tokens"]
|
tokens = batch["tokens"]
|
||||||
|
|
||||||
tokens = tokenizer.tokens_to_token_ids(tokens, intersperse_blank=True)
|
tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
|
||||||
tokens = k2.RaggedTensor(tokens)
|
tokens = k2.RaggedTensor(tokens)
|
||||||
row_splits = tokens.shape.row_splits(1)
|
row_splits = tokens.shape.row_splits(1)
|
||||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user