diff --git a/egs/ljspeech/TTS/.gitignore b/egs/ljspeech/TTS/.gitignore
new file mode 100644
index 000000000..1eef06a28
--- /dev/null
+++ b/egs/ljspeech/TTS/.gitignore
@@ -0,0 +1,4 @@
+build
+core.c
+*.so
+my-output*
diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
new file mode 100755
index 000000000..3aeb6add7
--- /dev/null
+++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Copyright      2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                          Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the LJSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    load_manifest,
+    load_manifest_lazy,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=4,
+        help="""Number of parallel jobs for computing the fbank features.
+        If it is less than 1, all available CPUs are used.
+        """,
+    )
+    return parser
+
+
+def compute_fbank_ljspeech(num_jobs: int):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    if num_jobs < 1:
+        num_jobs = os.cpu_count()
+
+    logging.info(f"num_jobs: {num_jobs}")
+    logging.info(f"src_dir: {src_dir}")
+    logging.info(f"output_dir: {output_dir}")
+
+    sampling_rate = 22050
+    frame_length = 1024 / sampling_rate  # (in seconds)
+    frame_shift = 256 / sampling_rate  # (in seconds)
+
+    prefix = "ljspeech"
+    suffix = "jsonl.gz"
+    partition = "all"
+
+    recordings = load_manifest(
+        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
+    )
+    supervisions = load_manifest(
+        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
+    )
+
+    # Differences with matcha-tts
+    # 1. we use pre-emphasis
+    # 2. we remove dc offset
+    # 3. we use a different window
+    # 4. we use a different mel filter bank matrix
+    # 5. 
we don't normalize features + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ../matcha/train.py + num_filters=80, + ) + extractor = Fbank(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank_ljspeech(args.num_jobs) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index 4ba88604c..33a8ac2ab 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -28,17 +28,33 @@ try: except ModuleNotFoundError as ex: raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n") +import argparse + from lhotse import CutSet, load_manifest from piper_phonemize import phonemize_espeak -def prepare_tokens_ljspeech(): - output_dir = Path("data/spectrogram") +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--in-out-dir", + type=Path, + required=True, + help="Input and output directory", + ) + + return parser + + +def prepare_tokens_ljspeech(in_out_dir): prefix = "ljspeech" suffix = "jsonl.gz" partition = "all" - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + cut_set = load_manifest(in_out_dir / f"{prefix}_cuts_{partition}.{suffix}") new_cuts = [] for cut in cut_set: @@ -56,11 +72,13 @@ def prepare_tokens_ljspeech(): new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + new_cut_set.to_file(in_out_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - prepare_tokens_ljspeech() + args = get_parser().parse_args() + + prepare_tokens_ljspeech(args.in_out_dir) diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index d5d78c619..b1525695f 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -71,9 +71,12 @@ class MatchaTTS(torch.nn.Module): # 🍵 spk_emb_dim=spk_emb_dim, ) - # self.update_data_statistics(data_statistics) - self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) - self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + if data_statistics is not None: + self.register_buffer("mel_mean", 
torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + else: + self.register_buffer("mel_mean", torch.tensor(0.0)) + self.register_buffer("mel_std", torch.tensor(1.0)) @torch.inference_mode() def synthesise( diff --git a/egs/ljspeech/TTS/matcha/tokenizer.py b/egs/ljspeech/TTS/matcha/tokenizer.py new file mode 120000 index 000000000..44a19b0f4 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/tokenizer.py @@ -0,0 +1 @@ +../vits/tokenizer.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 385dcba23..94e089d7e 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -8,20 +8,24 @@ from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Union +import k2 import torch +import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed -from matcha.data.text_mel_datamodule import TextMelDataModule -from icefall.env import get_env_info from matcha.models.matcha_tts import MatchaTTS +from matcha.tokenizer import Tokenizer +from matcha.utils.model import fix_len_compatibility from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from utils2 import MetricsTracker, plot_feature +from tts_datamodule import LJSpeechTtsDataModule +from utils2 import MetricsTracker from icefall.checkpoint import load_checkpoint, save_checkpoint from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.utils import AttributeDict, setup_logger, str2bool @@ -30,6 +34,20 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12335, + help="Master port to use for DDP training.", + ) + parser.add_argument( "--tensorboard", type=str2bool, @@ -64,6 +82,13 @@ def get_parser(): """, ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + parser.add_argument( "--seed", type=int, @@ -91,20 +116,14 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--batch-size", - type=int, - default=32, - ) - return parser def get_data_statistics(): return AttributeDict( { - "mel_mean": -5.517028331756592, - "mel_std": 2.0643954277038574, + "mel_mean": 0.0, + "mel_std": 1.0, } ) @@ -141,7 +160,6 @@ def _get_model_params() -> AttributeDict: encoder_params_p_dropout = 0.1 params = AttributeDict( { - "n_vocab": 178, "n_spks": 1, # for ljspeech. 
"spk_emb_dim": 64, "n_feats": n_feats, @@ -216,8 +234,8 @@ def get_params(): "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 2000, + "log_interval": 10, + "valid_interval": 1500, "env_info": get_env_info(), } ) @@ -271,9 +289,39 @@ def load_checkpoint_if_available( return saved_params +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + max_feature_length = fix_len_compatibility(features.shape[1]) + if max_feature_length > features.shape[1]: + pad = max_feature_length - features.shape[1] + features = torch.nn.functional.pad(features, (0, 0, 0, pad)) + + # features_lens[features_lens.argmax()] += pad + + return audio, audio_lens, features, features_lens, tokens, tokens_lens + + def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], + tokenizer: Tokenizer, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, rank: int = 0, @@ -281,19 +329,35 @@ def compute_validation_loss( """Run the validation process.""" model.eval() device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses # used to summary the stats over iterations tot_loss = MetricsTracker() with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - for key, value in batch.items(): - if isinstance(value, torch.Tensor): - batch[key] = value.to(device) - losses = model.get_losses(batch) - loss = sum(losses.values()) - batch_size = batch["x"].shape[0] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device) + + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + batch_size = len(batch["tokens"]) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -324,6 +388,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], + tokenizer: Tokenizer, optimizer: Optimizer, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -356,6 +421,7 @@ def train_one_epoch( """ model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses # used to track the stats over iterations in one epoch tot_loss = MetricsTracker() @@ -374,20 +440,35 @@ def train_one_epoch( params=params, optimizer=optimizer, scaler=scaler, - rank=rank, + rank=0, ) for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 - for key, value in batch.items(): - if isinstance(value, torch.Tensor): - batch[key] = value.to(device) + # audio: (N, T), float32 + # features: 
(N, T, C), float32 + # audio_lens, (N,), int32 + # features_lens, (N,), int32 + # tokens: List[List[str]], len(tokens) == N - batch_size = batch["x"].shape[0] + batch_size = len(batch["tokens"]) + + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) try: with autocast(enabled=params.use_fp16): - losses = model.get_losses(batch) + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) loss = sum(losses.values()) @@ -458,6 +539,7 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, + tokenizer=tokenizer, valid_dl=valid_dl, world_size=world_size, rank=rank, @@ -479,28 +561,31 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def main(): - parser = get_parser() - args = parser.parse_args() +def run(rank, world_size, args): params = get_params() - params.update(vars(args)) - params.data_args.batch_size = params.batch_size - del params.batch_size - fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", 0) + device = torch.device("cuda", rank) logging.info(f"Device: {device}") - print(f"Device: {device}") - print(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size logging.info(params) print(params) @@ -512,28 +597,35 @@ def main(): logging.info(f"Number of parameters: {num_param}") print(f"Number of parameters: {num_param}") - logging.info("About to create datamodule") - data_module = TextMelDataModule(hparams=params.data_args) - assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available(params=params, model=model) model.to(device) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + logging.info("About to create datamodule") + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - train_dl = data_module.train_dataloader() - valid_dl = data_module.val_dataloader() - - rank = 0 - for epoch in range(params.start_epoch, params.num_epochs + 1): logging.info(f"Start epoch {epoch}") fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) params.cur_epoch = epoch @@ -543,11 +635,14 @@ def main(): train_one_epoch( params=params, model=model, + tokenizer=tokenizer, optimizer=optimizer, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, tb_writer=tb_writer, + world_size=world_size, + rank=rank, ) if epoch % 
params.save_every_n == 0 or epoch == params.num_epochs: @@ -571,6 +666,23 @@ def main(): logging.info("Done!") + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + torch.set_num_threads(1) torch.set_num_interop_threads(1) diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py new file mode 100644 index 000000000..c2be815d9 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -0,0 +1,341 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ./train.py + num_filters=80, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ./train.py + num_filters=80, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ./train.py + num_filters=80, + ) + test = SpeechSynthesisDataset( + return_text=False, + 
return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/matcha/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py index 2b74b40f5..311744a78 100644 --- a/egs/ljspeech/TTS/matcha/utils/__init__.py +++ b/egs/ljspeech/TTS/matcha/utils/__init__.py @@ -3,3 +3,4 @@ # from matcha.utils.pylogger import get_pylogger # from matcha.utils.rich_utils import enforce_tags, print_config_tree # from matcha.utils.utils import extras, get_metric_value, task_wrapper +from matcha.utils.utils import intersperse diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 9ed0f93fd..e1cd0897e 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=0 +stage=-1 stop_stage=100 dl_dir=$PWD/download @@ -31,7 +31,19 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then python3 setup.py build_ext --inplace cd ../../ else - log "monotonic_align lib already built" + log "monotonic_align lib for vits already built" + fi + + if [ ! -f ./matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then + pushd matcha/utils/monotonic_align + python3 setup.py build_ext --inplace + mv -v matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ./ + rm -rf matcha + rm -rf build + rm core.c + popd + else + log "monotonic_align lib for matcha-tts already built" fi fi @@ -63,7 +75,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute spectrogram for LJSpeech" + log "Stage 2: Compute spectrogram for LJSpeech (used by ./vits)" mkdir -p data/spectrogram if [ ! -e data/spectrogram/.ljspeech.done ]; then ./local/compute_spectrogram_ljspeech.py @@ -71,7 +83,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ ! 
-e data/spectrogram/.ljspeech-validated.done ]; then - log "Validating data/spectrogram for LJSpeech" + log "Validating data/spectrogram for LJSpeech (used by ./vits)" python3 ./local/validate_manifest.py \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech-validated.done @@ -79,13 +91,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare phoneme tokens for LJSpeech" + log "Stage 3: Prepare phoneme tokens for LJSpeech (used by ./vits)" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then - ./local/prepare_tokens_ljspeech.py + ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/spectrogram mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech_with_token.done @@ -93,7 +105,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets (used by vits)" if [ ! -e data/spectrogram/.ljspeech_split.done ]; then lhotse subset --last 600 \ data/spectrogram/ljspeech_cuts_all.jsonl.gz \ @@ -126,3 +138,56 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate fbank (used by ./matcha)" + mkdir -p data/fbank + if [ ! -e data/fbank/.ljspeech.done ]; then + ./local/compute_fbank_ljspeech.py + touch data/fbank/.ljspeech.done + fi + + if [ ! -e data/fbank/.ljspeech-validated.done ]; then + log "Validating data/fbank for LJSpeech (used by ./matcha)" + python3 ./local/validate_manifest.py \ + data/fbank/ljspeech_cuts_all.jsonl.gz + touch data/fbank/.ljspeech-validated.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare phoneme tokens for LJSpeech (used by ./matcha)" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/fbank/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/fbank + mv data/fbank/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/fbank/ljspeech_cuts_all.jsonl.gz + touch data/fbank/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Split the LJSpeech cuts into train, valid and test sets (used by ./matcha)" + if [ ! 
-e data/fbank/.ljspeech_split.done ]; then + lhotse subset --last 600 \ + data/fbank/ljspeech_cuts_all.jsonl.gz \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz \ + data/fbank/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz \ + data/fbank/ljspeech_cuts_test.jsonl.gz + + rm data/fbank/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/fbank/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/fbank/ljspeech_cuts_all.jsonl.gz \ + data/fbank/ljspeech_cuts_train.jsonl.gz + touch data/fbank/.ljspeech_split.done + fi +fi
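
Usage sketch (not part of the diff; values below are illustrative and rely only on the defaults added above): once stages 6-8 have produced data/fbank/ljspeech_cuts_{train,valid,test}.jsonl.gz and data/tokens.txt exists from stage 5, single-GPU training could be launched from egs/ljspeech/TTS roughly as follows. Flags not shown in this diff (e.g. an experiment-directory option) may differ.

  cd egs/ljspeech/TTS
  ./prepare.sh                      # default stage=-1, stop_stage=100 runs all stages
  python3 ./matcha/train.py \
    --world-size 1 \
    --tokens data/tokens.txt \
    --manifest-dir data/fbank \
    --max-duration 200 \
    --num-workers 2 \
    --tensorboard true

Setting --world-size to a value greater than 1 would instead spawn one DDP process per GPU via mp.spawn, as wired up in main() above.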