switch to piper-phonemize

Fangjun Kuang 2024-10-18 22:14:14 +08:00
parent 56d3b92f3f
commit 7077b4f99a
9 changed files with 746 additions and 60 deletions

egs/ljspeech/TTS/.gitignore (new file)

@@ -0,0 +1,4 @@
build
core.c
*.so
my-output*

egs/ljspeech/TTS/local/compute_fbank_ljspeech.py (new file)

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LJSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    load_manifest,
    load_manifest_lazy,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=4,
        help="""Number of parallel jobs to use for computing fbank features.
        If the value is less than 1, all available CPU cores are used.
        """,
    )
    return parser


def compute_fbank_ljspeech(num_jobs: int):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    if num_jobs < 1:
        num_jobs = os.cpu_count()

    logging.info(f"num_jobs: {num_jobs}")
    logging.info(f"src_dir: {src_dir}")
    logging.info(f"output_dir: {output_dir}")

    sampling_rate = 22050
    frame_length = 1024 / sampling_rate  # (in second)
    frame_shift = 256 / sampling_rate  # (in second)

    prefix = "ljspeech"
    suffix = "jsonl.gz"
    partition = "all"

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
    )
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )

    # Differences with matcha-tts
    # 1. we use pre-emphasis
    # 2. we remove dc offset
    # 3. we use a different window
    # 4. we use a different mel filter bank matrix
    # 5. we don't normalize features
    config = FbankConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=True,
        low_freq=0,
        high_freq=8000,
        # should be identical to n_feats in ../matcha/train.py
        num_filters=80,
    )
    extractor = Fbank(config)

    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{cuts_filename} already exists - skipping.")
            return
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )
        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_parser().parse_args()
    compute_fbank_ljspeech(args.num_jobs)
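The frame geometry above is the usual 1024-sample FFT window and 256-sample hop for LJSpeech at 22050 Hz, only expressed in seconds as lhotse expects. A quick sanity check of the numbers (a standalone sketch, not part of the recipe):

# Standalone sanity check of the fbank frame geometry used above (not part of the recipe).
sampling_rate = 22050
frame_length = 1024 / sampling_rate  # ~0.0464 s (46.4 ms) analysis window
frame_shift = 256 / sampling_rate    # ~0.0116 s (11.6 ms) hop
frames_per_second = sampling_rate / 256  # ~86.1 frames per second
# A 10-second LJSpeech clip therefore yields on the order of 860 80-dim fbank frames.
print(frame_length, frame_shift, frames_per_second)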

egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py

@@ -28,17 +28,33 @@ try:
 except ModuleNotFoundError as ex:
     raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n")

+import argparse
+
 from lhotse import CutSet, load_manifest
 from piper_phonemize import phonemize_espeak

-def prepare_tokens_ljspeech():
-    output_dir = Path("data/spectrogram")
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--in-out-dir",
+        type=Path,
+        required=True,
+        help="Input and output directory",
+    )
+
+    return parser
+
+
+def prepare_tokens_ljspeech(in_out_dir):
     prefix = "ljspeech"
     suffix = "jsonl.gz"
     partition = "all"

-    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+    cut_set = load_manifest(in_out_dir / f"{prefix}_cuts_{partition}.{suffix}")

     new_cuts = []
     for cut in cut_set:
@@ -56,11 +72,13 @@ def prepare_tokens_ljspeech():
         new_cuts.append(cut)

     new_cut_set = CutSet.from_cuts(new_cuts)
-    new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+    new_cut_set.to_file(in_out_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")


 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

-    prepare_tokens_ljspeech()
+    args = get_parser().parse_args()
+    prepare_tokens_ljspeech(args.in_out_dir)
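prepare_tokens_ljspeech.py attaches piper-phonemize phonemes to each cut; the change above simply lets the same script write to either data/spectrogram (vits) or data/fbank (matcha) via --in-out-dir. A minimal sketch of the underlying phonemize_espeak call (the voice name and sample text are illustrative, not taken from the script):

# Minimal sketch of the piper_phonemize call used by prepare_tokens_ljspeech.py.
# The "en-us" voice and the example sentence are illustrative.
from piper_phonemize import phonemize_espeak

phoneme_lists = phonemize_espeak("How are you doing today?", "en-us")
# phonemize_espeak returns one list of IPA-style phoneme strings per sentence;
# the script then stores a flat token list on each cut.
tokens = [p for sentence in phoneme_lists for p in sentence]
print(tokens)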

egs/ljspeech/TTS/matcha/models/matcha_tts.py

@@ -71,9 +71,12 @@ class MatchaTTS(torch.nn.Module):  # 🍵
             spk_emb_dim=spk_emb_dim,
         )

-        # self.update_data_statistics(data_statistics)
-        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
-        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
+        if data_statistics is not None:
+            self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
+            self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
+        else:
+            self.register_buffer("mel_mean", torch.tensor(0.0))
+            self.register_buffer("mel_std", torch.tensor(1.0))

     @torch.inference_mode()
     def synthesise(
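Registering mel_mean = 0.0 and mel_std = 1.0 when no statistics are passed makes the model's mel (de)normalization an identity, which lines up with the new get_data_statistics() defaults in train.py and with the "we don't normalize features" note in local/compute_fbank_ljspeech.py.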

egs/ljspeech/TTS/matcha/tokenizer.py (new symlink)

@@ -0,0 +1 @@
../vits/tokenizer.py
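matcha/tokenizer.py is now just a symlink to the vits recipe's tokenizer, so both recipes share one token inventory (data/tokens.txt). A rough sketch of the small surface train.py relies on, using only the attributes and arguments visible in this diff (the phoneme strings are illustrative):

# Rough sketch of how train.py uses the shared vits Tokenizer (phoneme strings are illustrative).
from matcha.tokenizer import Tokenizer

tokenizer = Tokenizer("data/tokens.txt")  # the --tokens default in train.py
print(tokenizer.vocab_size, tokenizer.pad_id)

token_ids = tokenizer.tokens_to_token_ids(
    [["h", "ə", "l", "ˈo", "ʊ"]],  # one list of phoneme tokens per utterance
    intersperse_blank=True,        # insert blank ids between phonemes
    add_sos=True,
    add_eos=True,
)
print(token_ids)  # one list of ids per utterance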

egs/ljspeech/TTS/matcha/train.py

@@ -8,20 +8,24 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union

+import k2
 import torch
+import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
-from matcha.data.text_mel_datamodule import TextMelDataModule
-from icefall.env import get_env_info
 from matcha.models.matcha_tts import MatchaTTS
+from matcha.tokenizer import Tokenizer
+from matcha.utils.model import fix_len_compatibility
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from utils2 import MetricsTracker, plot_feature
+from tts_datamodule import LJSpeechTtsDataModule
+from utils2 import MetricsTracker

 from icefall.checkpoint import load_checkpoint, save_checkpoint
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, setup_logger, str2bool
@@ -30,6 +34,20 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12335,
+        help="Master port to use for DDP training.",
+    )
+
     parser.add_argument(
         "--tensorboard",
         type=str2bool,
@@ -64,6 +82,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to vocabulary.""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -91,20 +116,14 @@ def get_parser():
         help="Whether to use half precision training.",
     )

-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=32,
-    )
-
     return parser


 def get_data_statistics():
     return AttributeDict(
         {
-            "mel_mean": -5.517028331756592,
-            "mel_std": 2.0643954277038574,
+            "mel_mean": 0.0,
+            "mel_std": 1.0,
         }
     )
@@ -141,7 +160,6 @@ def _get_model_params() -> AttributeDict:
     encoder_params_p_dropout = 0.1

     params = AttributeDict(
         {
-            "n_vocab": 178,
             "n_spks": 1,  # for ljspeech.
             "spk_emb_dim": 64,
             "n_feats": n_feats,
@@ -216,8 +234,8 @@ def get_params():
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": -1,  # 0
-            "log_interval": 50,
-            "valid_interval": 2000,
+            "log_interval": 10,
+            "valid_interval": 1500,
             "env_info": get_env_info(),
         }
     )
@@ -271,9 +289,39 @@
     return saved_params


+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+    """Parse batch data"""
+    audio = batch["audio"].to(device)
+    features = batch["features"].to(device)
+    audio_lens = batch["audio_lens"].to(device)
+    features_lens = batch["features_lens"].to(device)
+    tokens = batch["tokens"]
+
+    tokens = tokenizer.tokens_to_token_ids(
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+    )
+    tokens = k2.RaggedTensor(tokens)
+    row_splits = tokens.shape.row_splits(1)
+    tokens_lens = row_splits[1:] - row_splits[:-1]
+    tokens = tokens.to(device)
+    tokens_lens = tokens_lens.to(device)
+    # a tensor of shape (B, T)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
+
+    max_feature_length = fix_len_compatibility(features.shape[1])
+    if max_feature_length > features.shape[1]:
+        pad = max_feature_length - features.shape[1]
+        features = torch.nn.functional.pad(features, (0, 0, 0, pad))
+        # features_lens[features_lens.argmax()] += pad
+
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+
+
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
@@ -281,19 +329,35 @@
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses

     # used to summary the stats over iterations
     tot_loss = MetricsTracker()

     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            for key, value in batch.items():
-                if isinstance(value, torch.Tensor):
-                    batch[key] = value.to(device)
-            losses = model.get_losses(batch)
-            loss = sum(losses.values())
-            batch_size = batch["x"].shape[0]
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device)
+
+            losses = get_losses(
+                {
+                    "x": tokens,
+                    "x_lengths": tokens_lens,
+                    "y": features.permute(0, 2, 1),
+                    "y_lengths": features_lens,
+                    "spks": None,  # should change it for multi-speakers
+                    "durations": None,
+                }
+            )
+
+            batch_size = len(batch["tokens"])

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -324,6 +388,7 @@
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
     optimizer: Optimizer,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -356,6 +421,7 @@
     """
     model.train()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses

     # used to track the stats over iterations in one epoch
     tot_loss = MetricsTracker()
@@ -374,20 +440,35 @@
             params=params,
             optimizer=optimizer,
             scaler=scaler,
-            rank=rank,
+            rank=0,
         )

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
-        for key, value in batch.items():
-            if isinstance(value, torch.Tensor):
-                batch[key] = value.to(device)
+        # audio: (N, T), float32
+        # features: (N, T, C), float32
+        # audio_lens, (N,), int32
+        # features_lens, (N,), int32
+        # tokens: List[List[str]], len(tokens) == N

-        batch_size = batch["x"].shape[0]
+        batch_size = len(batch["tokens"])
+
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )

         try:
             with autocast(enabled=params.use_fp16):
-                losses = model.get_losses(batch)
+                losses = get_losses(
+                    {
+                        "x": tokens,
+                        "x_lengths": tokens_lens,
+                        "y": features.permute(0, 2, 1),
+                        "y_lengths": features_lens,
+                        "spks": None,  # should change it for multi-speakers
+                        "durations": None,
+                    }
+                )

                 loss = sum(losses.values())
@@ -458,6 +539,7 @@
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
+                tokenizer=tokenizer,
                 valid_dl=valid_dl,
                 world_size=world_size,
                 rank=rank,
@@ -479,28 +561,31 @@
         params.best_train_loss = params.train_loss


-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-
+def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))

-    params.data_args.batch_size = params.batch_size
-    del params.batch_size
-
     fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")

-    tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
+        device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
-    print(f"Device: {device}")
-    print(f"Device: {device}")
+
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size

     logging.info(params)
-    print(params)
@@ -512,28 +597,35 @@
     logging.info(f"Number of parameters: {num_param}")
     print(f"Number of parameters: {num_param}")

-    logging.info("About to create datamodule")
-    data_module = TextMelDataModule(hparams=params.data_args)
-
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)

+    logging.info("About to create datamodule")
+    ljspeech = LJSpeechTtsDataModule(args)
+
+    train_cuts = ljspeech.train_cuts()
+    train_dl = ljspeech.train_dataloaders(train_cuts)
+
+    valid_cuts = ljspeech.valid_cuts()
+    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])

-    train_dl = data_module.train_dataloader()
-    valid_dl = data_module.val_dataloader()
-
-    rank = 0
-
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         logging.info(f"Start epoch {epoch}")
         fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)

         params.cur_epoch = epoch
@@ -543,11 +635,14 @@
         train_one_epoch(
             params=params,
             model=model,
+            tokenizer=tokenizer,
             optimizer=optimizer,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
             tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
         )

         if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
@@ -571,6 +666,23 @@
     logging.info("Done!")

+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    LJSpeechTtsDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
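prepare_input() above leans on k2 ragged tensors to turn the per-utterance token-id lists into a padded (B, T) tensor plus a lengths vector. A standalone sketch of just that padding step (the token ids and pad id are made up):

# Standalone sketch of the k2-based padding done in prepare_input() (ids are made up).
import k2

token_ids = [[3, 7, 7, 2], [5, 9], [4, 4, 4, 4, 4, 1]]
ragged = k2.RaggedTensor(token_ids)

row_splits = ragged.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([4, 2, 6], dtype=torch.int32)

pad_id = 0  # train.py uses tokenizer.pad_id here
tokens = ragged.pad(mode="constant", padding_value=pad_id)  # shape (3, 6)
print(tokens)
print(tokens_lens)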

egs/ljspeech/TTS/matcha/tts_datamodule.py (new file)

@@ -0,0 +1,341 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LJSpeechTtsDataModule:
    """
    DataModule for tts experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )

        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )

        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )

        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )

        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )

        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )

        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=False,
            return_tokens=True,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = FbankConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
                frame_shift=256 / sampling_rate,  # (in second)
                use_fft_mag=True,
                low_freq=0,
                high_freq=8000,
                # should be identical to n_feats in ./train.py
                num_filters=80,
            )
            train = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = FbankConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
                frame_shift=256 / sampling_rate,  # (in second)
                use_fft_mag=True,
                low_freq=0,
                high_freq=8000,
                # should be identical to n_feats in ./train.py
                num_filters=80,
            )
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = FbankConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
                frame_shift=256 / sampling_rate,  # (in second)
                use_fft_mag=True,
                low_freq=0,
                high_freq=8000,
                # should be identical to n_feats in ./train.py
                num_filters=80,
            )
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
        )
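Taken together, the data module yields batches shaped the way prepare_input() in matcha/train.py expects. A rough usage sketch, assuming the fbank manifests from prepare.sh stages 6-8 already exist under data/fbank:

# Rough usage sketch of LJSpeechTtsDataModule (assumes data/fbank manifests exist).
import argparse

from tts_datamodule import LJSpeechTtsDataModule

parser = argparse.ArgumentParser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args([])  # all of the options above have defaults

ljspeech = LJSpeechTtsDataModule(args)
train_dl = ljspeech.train_dataloaders(ljspeech.train_cuts())

batch = next(iter(train_dl))
# Expected fields (see the comments in matcha/train.py):
#   batch["audio"]:         (N, T) float32
#   batch["features"]:      (N, T, C) float32, with C == 80 mel bins
#   batch["audio_lens"]:    (N,) int32
#   batch["features_lens"]: (N,) int32
#   batch["tokens"]:        List[List[str]] of phoneme tokens
print(batch["features"].shape, len(batch["tokens"]))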

egs/ljspeech/TTS/matcha/utils/__init__.py

@@ -3,3 +3,4 @@
 # from matcha.utils.pylogger import get_pylogger
 # from matcha.utils.rich_utils import enforce_tags, print_config_tree
 # from matcha.utils.utils import extras, get_metric_value, task_wrapper
+from matcha.utils.utils import intersperse

egs/ljspeech/TTS/prepare.sh

@@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

-stage=0
+stage=-1
 stop_stage=100

 dl_dir=$PWD/download

@@ -31,7 +31,19 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
     python3 setup.py build_ext --inplace
     cd ../../
   else
-    log "monotonic_align lib already built"
+    log "monotonic_align lib for vits already built"
+  fi
+
+  if [ ! -f ./matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then
+    pushd matcha/utils/monotonic_align
+    python3 setup.py build_ext --inplace
+    mv -v matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ./
+    rm -rf matcha
+    rm -rf build
+    rm core.c
+    popd
+  else
+    log "monotonic_align lib for matcha-tts already built"
   fi
 fi
@@ -63,7 +75,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Compute spectrogram for LJSpeech"
+  log "Stage 2: Compute spectrogram for LJSpeech (used by ./vits)"
   mkdir -p data/spectrogram
   if [ ! -e data/spectrogram/.ljspeech.done ]; then
     ./local/compute_spectrogram_ljspeech.py
@@ -71,7 +83,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   fi

   if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then
-    log "Validating data/spectrogram for LJSpeech"
+    log "Validating data/spectrogram for LJSpeech (used by ./vits)"
     python3 ./local/validate_manifest.py \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz
     touch data/spectrogram/.ljspeech-validated.done
@@ -79,13 +91,13 @@
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  log "Stage 3: Prepare phoneme tokens for LJSpeech (used by ./vits)"
   # We assume you have installed piper_phonemize and espnet_tts_frontend.
   # If not, please install them with:
   # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
   # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
-    ./local/prepare_tokens_ljspeech.py
+    ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/spectrogram
     mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz
     touch data/spectrogram/.ljspeech_with_token.done
@@ -93,7 +105,7 @@
 fi

 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
+  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets (used by vits)"
   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
     lhotse subset --last 600 \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz \
@@ -126,3 +138,56 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     ./local/prepare_token_file.py --tokens data/tokens.txt
   fi
 fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Generate fbank (used by ./matcha)"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.ljspeech.done ]; then
+    ./local/compute_fbank_ljspeech.py
+    touch data/fbank/.ljspeech.done
+  fi
+
+  if [ ! -e data/fbank/.ljspeech-validated.done ]; then
+    log "Validating data/fbank for LJSpeech (used by ./matcha)"
+    python3 ./local/validate_manifest.py \
+      data/fbank/ljspeech_cuts_all.jsonl.gz
+    touch data/fbank/.ljspeech-validated.done
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare phoneme tokens for LJSpeech (used by ./matcha)"
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
+  # If not, please install them with:
+  # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
+  # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+  if [ ! -e data/fbank/.ljspeech_with_token.done ]; then
+    ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/fbank
+    mv data/fbank/ljspeech_cuts_with_tokens_all.jsonl.gz \
+      data/fbank/ljspeech_cuts_all.jsonl.gz
+    touch data/fbank/.ljspeech_with_token.done
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Split the LJSpeech cuts into train, valid and test sets (used by ./matcha)"
+  if [ ! -e data/fbank/.ljspeech_split.done ]; then
+    lhotse subset --last 600 \
+      data/fbank/ljspeech_cuts_all.jsonl.gz \
+      data/fbank/ljspeech_cuts_validtest.jsonl.gz
+    lhotse subset --first 100 \
+      data/fbank/ljspeech_cuts_validtest.jsonl.gz \
+      data/fbank/ljspeech_cuts_valid.jsonl.gz
+    lhotse subset --last 500 \
+      data/fbank/ljspeech_cuts_validtest.jsonl.gz \
+      data/fbank/ljspeech_cuts_test.jsonl.gz
+
+    rm data/fbank/ljspeech_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
+    lhotse subset --first $n \
+      data/fbank/ljspeech_cuts_all.jsonl.gz \
+      data/fbank/ljspeech_cuts_train.jsonl.gz
+    touch data/fbank/.ljspeech_split.done
+  fi
+fi
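For the standard LJSpeech release (13,100 clips, one cut per clip), this split works out to roughly 12,500 training cuts, 100 validation cuts, and 500 test cuts: the last 600 cuts are carved off and divided 100/500, and the first n = 13,100 - 600 cuts become the training set, mirroring the vits split in stage 4.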