diff --git a/egs/baker_zh/TTS/.gitignore b/egs/baker_zh/TTS/.gitignore
index 9c40c787f..8a3b1af65 100644
--- a/egs/baker_zh/TTS/.gitignore
+++ b/egs/baker_zh/TTS/.gitignore
@@ -1 +1,5 @@
 path.sh
+*.wav
+generator_v1
+generator_v2
+generator_v3
diff --git a/egs/baker_zh/TTS/local/convert_text_to_tokens.py b/egs/baker_zh/TTS/local/convert_text_to_tokens.py
index a20165089..9d5beafae 100755
--- a/egs/baker_zh/TTS/local/convert_text_to_tokens.py
+++ b/egs/baker_zh/TTS/local/convert_text_to_tokens.py
@@ -30,6 +30,8 @@ punctuations_re = [
         ("’", "'"),
         ("：", ":"),
         ("、", ","),
+        ("B", "逼"),
+        ("P", "批"),
     ]
 ]
@@ -108,7 +110,7 @@ def main():
         text_list = split_text(text)
         tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
 
-        c.supervisions[0].tokens = tokens
+        c.tokens = tokens
 
     cuts.to_file(args.out_file)
diff --git a/egs/baker_zh/TTS/local/generate_tokens.py b/egs/baker_zh/TTS/local/generate_tokens.py
index 0f469aaf3..b2abe1a71 100755
--- a/egs/baker_zh/TTS/local/generate_tokens.py
+++ b/egs/baker_zh/TTS/local/generate_tokens.py
@@ -43,6 +43,22 @@ def generate_token_list() -> List[str]:
         t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
         token_set.add(t)
 
+    # Add every syllable both without a tone digit and with tones 1-4,
+    # since lazy_pinyin omits the digit for the neutral tone.
+    no_digit = set()
+    for t in token_set:
+        if t[-1] not in "1234":
+            no_digit.add(t)
+        else:
+            no_digit.add(t[:-1])
+
+    no_digit.add("dei")
+    no_digit.add("tou")
+    no_digit.add("dia")
+
+    for t in no_digit:
+        token_set.add(t)
+        for i in range(1, 5):
+            token_set.add(f"{t}{i}")
+
     ans = list(token_set)
     ans.sort()
diff --git a/egs/baker_zh/TTS/matcha/__init__.py b/egs/baker_zh/TTS/matcha/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/baker_zh/TTS/matcha/infer.py b/egs/baker_zh/TTS/matcha/infer.py
new file mode 100755
index 000000000..b7e785a04
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/infer.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import datetime as dt
+import json
+import logging
+from pathlib import Path
+
+import soundfile as sf
+import torch
+import torch.nn as nn
+from hifigan.config import v1, v2, v3
+from hifigan.denoiser import Denoiser
+from hifigan.models import Generator as HiFiGAN
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import BakerZhTtsDataModule
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+from local.convert_text_to_tokens import split_text
+from pypinyin import lazy_pinyin, Style
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=4000,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--vocoder",
+        type=Path,
+        default="./generator_v1",
+        help="Path to the vocoder",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+        help="Path to the tokens.txt",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to the cmvn statistics (mean and std of the fbank features).""",
+    )
+
+    # The following arguments are used for inference on a single text
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=False,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=False,
+        help="The filename of the wave to save the generated speech",
+    )
+
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for baker_zh)",
+    )
+
+    return parser
+
+
+def load_vocoder(checkpoint_path: Path) -> nn.Module:
+    checkpoint_path = str(checkpoint_path)
+    if checkpoint_path.endswith("v1"):
+        h = AttributeDict(v1)
+    elif checkpoint_path.endswith("v2"):
+        h = AttributeDict(v2)
+    elif checkpoint_path.endswith("v3"):
+        h = AttributeDict(v3)
+    else:
+        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
+
+    hifigan = HiFiGAN(h).to("cpu")
+    hifigan.load_state_dict(
+        torch.load(checkpoint_path, map_location="cpu")["generator"]
+    )
+    _ = hifigan.eval()
+    hifigan.remove_weight_norm()
+    return hifigan
+
+
+def to_waveform(
+    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
+) -> torch.Tensor:
+    audio = vocoder(mel).clamp(-1, 1)
+    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
+    return audio
+
+
+def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
+    text = split_text(text)
+    tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True)
+
+    x = tokenizer.texts_to_token_ids([tokens])
+    x = torch.tensor(x, dtype=torch.long, device=device)
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
+
+
+def synthesize(
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    n_timesteps: int,
+    text: str,
+    length_scale: float,
+    temperature: float,
+    device: str = "cpu",
+    spks=None,
+) -> dict:
+    text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
+    start_t = dt.datetime.now()
+    output = model.synthesise(
+        text_processed["x"],
+        text_processed["x_lengths"],
+        n_timesteps=n_timesteps,
+        temperature=temperature,
+        spks=spks,
+        length_scale=length_scale,
+    )
+    # merge everything into one dict
+    output.update({"start_t": start_t, **text_processed})
+    return output
+
+
+def infer_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    vocoder: nn.Module,
+    denoiser: nn.Module,
+    tokenizer: Tokenizer,
+) -> None:
+    """Decode the dataset.
+    The ground-truth and generated audio pairs are saved to `params.save_wav_dir`.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      vocoder:
+        The vocoder that converts mel spectrograms to waveforms.
+      denoiser:
+        Used to remove the vocoder bias from the generated audio.
+      tokenizer:
+        Used to convert text to pinyin tokens.
+    """
+
+    device = next(model.parameters()).device
+    num_cuts = 0
+    log_interval = 5
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
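+    # len() is undefined for dataloaders backed by iterable-style datasets,
+    # hence the "?" fallback above.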
+
+    for batch_idx, batch in enumerate(dl):
+        batch_size = len(batch["tokens"])
+
+        texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
+
+        audio = batch["audio"]
+        audio_lens = batch["audio_lens"].tolist()
+        cut_ids = [cut.id for cut in batch["cut"]]
+
+        for i in range(batch_size):
+            output = synthesize(
+                model=model,
+                tokenizer=tokenizer,
+                n_timesteps=params.n_timesteps,
+                text=texts[i],
+                length_scale=params.length_scale,
+                temperature=params.temperature,
+                device=device,
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
+                data=output["waveform"],
+                samplerate=params.data_args.sampling_rate,
+                subtype="PCM_16",
+            )
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
+                data=audio[i].numpy(),
+                samplerate=params.data_args.sampling_rate,
+                subtype="PCM_16",
+            )
+
+        num_cuts += batch_size
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, {num_cuts} cuts processed so far")
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    BakerZhTtsDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"Device: {device}")
+
+    tokenizer = Tokenizer(params.tokens)
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    # Number of ODE solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
+    model.eval()
+
+    # We need cut ids to organize the TTS results.
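+    # (The data module attaches each batch's cuts under the "cut" key only
+    # when return_cuts is True; infer_dataset relies on that key.)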
+    args.return_cuts = True
+    baker_zh = BakerZhTtsDataModule(args)
+
+    test_cuts = baker_zh.test_cuts()
+    test_dl = baker_zh.test_dataloaders(test_cuts)
+
+    if not Path(params.vocoder).is_file():
+        raise ValueError(f"{params.vocoder} does not exist")
+
+    vocoder = load_vocoder(params.vocoder)
+    vocoder.to(device)
+
+    denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser.to(device)
+
+    if params.input_text is not None and params.output_wav is not None:
+        logging.info("Synthesizing a single text")
+        output = synthesize(
+            model=model,
+            tokenizer=tokenizer,
+            n_timesteps=params.n_timesteps,
+            text=params.input_text,
+            length_scale=params.length_scale,
+            temperature=params.temperature,
+            device=device,
+        )
+        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+        sf.write(
+            file=params.output_wav,
+            data=output["waveform"],
+            samplerate=params.sampling_rate,
+            subtype="PCM_16",
+        )
+    else:
+        logging.info("Decoding the test set")
+        infer_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            vocoder=vocoder,
+            denoiser=denoiser,
+            tokenizer=tokenizer,
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py
index d5c277ffe..a5000c877 100644
--- a/egs/baker_zh/TTS/matcha/tokenizer.py
+++ b/egs/baker_zh/TTS/matcha/tokenizer.py
@@ -80,7 +80,7 @@ class Tokenizer(object):
         token_ids = []
         for t in tokens_list:
             if t not in self.token2id:
-                logging.warning(f"Skip OOV {t}")
+                logging.warning(f"Skip OOV token {t} in sentence: {sentence}")
                 continue
 
             if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
diff --git a/egs/baker_zh/TTS/matcha/train.py b/egs/baker_zh/TTS/matcha/train.py
index 814cd1483..ed2ba49b9 100755
--- a/egs/baker_zh/TTS/matcha/train.py
+++ b/egs/baker_zh/TTS/matcha/train.py
@@ -315,7 +315,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
 
-    tokens = tokenizer.tokens_to_token_ids(tokens, intersperse_blank=True)
+    tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
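
Note on the tokenization used throughout this patch: convert_text_to_tokens.py, generate_tokens.py, and infer.py all rely on pypinyin's lazy_pinyin with Style.TONE3 and tone_sandhi=True. A minimal sketch of that path (illustrative only, not part of the diff; the printed output is what pypinyin is expected to produce):

    # Requires: pip install pypinyin
    from pypinyin import Style, lazy_pinyin

    # Style.TONE3 appends the tone digit (1-4) to each syllable; the neutral
    # tone gets no digit. tone_sandhi=True applies rules such as third-tone
    # sandhi, which turns ni3 hao3 into ni2 hao3.
    tokens = lazy_pinyin("你好", style=Style.TONE3, tone_sandhi=True)
    print(tokens)  # ['ni2', 'hao3']

Because a syllable may thus appear either with or without a tone digit, generate_tokens.py adds the bare form and all four tonal variants of every syllable to the token list; any variant missing from data/tokens.txt would trigger the "Skip OOV" warning in tokenizer.py above. For single-text inference, infer.py is invoked with the flags defined in get_parser, e.g. --epoch 4000 --vocoder ./generator_v1 --input-text <Chinese text> --output-wav ./out.wav (the paths shown are the argparse defaults; the input text is a placeholder).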