ready to train

This commit is contained in:
Fangjun Kuang 2024-12-26 11:51:56 +08:00
parent f4d6fb06aa
commit e4d40baaf5
11 changed files with 1633 additions and 2 deletions

View File

@@ -0,0 +1 @@
../matcha/audio.py

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the baker-zh dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
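
Usage (see also stage 5 in ./prepare.sh):

    ./local/compute_fbank_baker_zh.py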
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-jobs",
type=int,
default=4,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
return parser
def compute_fbank_baker_zh(num_jobs: int):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
if num_jobs < 1:
num_jobs = os.cpu_count()
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
prefix = "baker_zh"
suffix = "jsonl.gz"
extractor = MatchaFbank(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts.{suffix}"
logging.info(f"Processing {cuts_filename}")
cut_set = load_manifest(src_dir / cuts_filename).resample(22050)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
# Torch's multithreaded behavior needs to be disabled or
    # it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_parser().parse_args()
compute_fbank_baker_zh(args.num_jobs)

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script computes the mean and std of the fbank features.
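
Usage (see also stage 7 in ./prepare.sh):

    ./local/compute_fbank_statistics.py \
      ./data/fbank/baker_zh_cuts_train.jsonl.gz \
      ./data/fbank/cmvn.json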
"""
import argparse
import json
import logging
from pathlib import Path
import torch
from lhotse import CutSet, load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
parser.add_argument(
"cmvn",
type=Path,
help="Path to the cmvn.json",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(
f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
)
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet), type(cut_set)
feat_dim = cut_set[0].features.num_features
num_frames = 0
s = 0
sq = 0
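    # Accumulate the sum and the sum of squares over all frames and all
    # mel bins; then mean = s / (num_frames * feat_dim) and
    # var = E[x^2] - mean^2.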
for c in cut_set:
f = torch.from_numpy(c.load_features())
num_frames += f.shape[0]
s += f.sum()
sq += f.square().sum()
fbank_mean = s / (num_frames * feat_dim)
fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
print("fbank var", fbank_var)
fbank_std = fbank_var.sqrt()
with open(args.cmvn, "w") as f:
json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
f.write("\n")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python3
import argparse
import re
from typing import List
import jieba
from lhotse import load_manifest
from pypinyin import lazy_pinyin, load_phrases_dict, Style
load_phrases_dict(
{
"行长": [["hang2"], ["zhang3"]],
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
}
)
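# "行" is polyphonic (xing2 / hang2); the phrases above pin it to "hang2"
# so that lazy_pinyin() picks the intended reading in these contexts.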
whitespace_re = re.compile(r"\s+")
punctuations_re = [
(re.compile(x[0], re.IGNORECASE), x[1])
for x in [
("", ","),
("", "."),
("", "!"),
("", "?"),
("", '"'),
("", '"'),
("", "'"),
("", "'"),
("", ":"),
("", ","),
]
]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--in-file",
type=str,
required=True,
help="Input cutset.",
)
parser.add_argument(
"--out-file",
type=str,
required=True,
help="Output cutset.",
)
return parser
def normalize_white_spaces(text):
    return whitespace_re.sub(" ", text)
def normalize_punctuations(text):
for regex, replacement in punctuations_re:
text = re.sub(regex, replacement, text)
return text
def split_text(text: str) -> List[str]:
"""
    Example input:  '你好呀，You are 一个好人。 去银行存钱？How about you?'
    Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?']
"""
text = text.lower()
text = normalize_white_spaces(text)
text = normalize_punctuations(text)
ans = []
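    # jieba emits spaces as separate segments; the branches below merge
    # consecutive English segments (and the spaces between them) back into
    # a single phrase such as "how about you".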
for seg in jieba.cut(text):
if seg in ",.!?:\"'":
ans.append(seg)
elif seg == " " and len(ans) > 0:
if ord("a") <= ord(ans[-1][-1]) <= ord("z"):
ans[-1] += seg
elif ord("a") <= ord(seg[0]) <= ord("z"):
if len(ans) == 0:
ans.append(seg)
continue
if ans[-1][-1] == " ":
ans[-1] += seg
continue
ans.append(seg)
else:
ans.append(seg)
ans = [s.strip() for s in ans]
return ans
def main():
args = get_parser().parse_args()
cuts = load_manifest(args.in_file)
for c in cuts:
assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions)
text = c.supervisions[0].normalized_text
text_list = split_text(text)
tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
c.supervisions[0].tokens = tokens
cuts.to_file(args.out_file)
print(f"saved to {args.out_file}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
../matcha/fbank.py

6
egs/baker_zh/TTS/local/generate_tokens.py Normal file → Executable file
View File

@@ -46,9 +46,13 @@ def generate_token_list() -> List[str]:
ans = list(token_set)
ans.sort()
punctuations = list(",.!?:\"'")
ans = punctuations + ans
    # Use ID 0 for the blank token " "
    # Use ID 1 for the padding token "_"
    ans.insert(0, " ")
    ans.insert(1, "_")
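    # Resulting layout: ID 0 -> " " (blank), ID 1 -> "_" (padding), then the
    # punctuation marks, then the sorted pinyin tokens.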
return ans

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/spectrogram/baker_zh_cuts_all.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet), type(cut_set)
validate_for_tts(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/matcha/tokenizer.py

View File

@@ -0,0 +1,119 @@
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import logging
from typing import Dict, List
import tacotron_cleaner.cleaners
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
)
from utils import intersperse
# This tokenizer supports both English and Chinese.
# We assume you have used
# ../local/convert_text_to_tokens.py
# to process your text
class Tokenizer(object):
def __init__(self, tokens: str):
"""
Args:
tokens: the file that maps tokens to ids
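            Each line has the format "token id"; a line with a single
            field is the whitespace token, whose ID is that field.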
"""
# Parse token file
self.token2id: Dict[str, int] = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
id = int(info[0])
else:
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
self.pad_id = self.token2id["_"] # padding
self.space_id = self.token2id[" "] # word separator (whitespace)
self.vocab_size = len(self.token2id)
def texts_to_token_ids(
self,
sentence_list: List[List[str]],
intersperse_blank: bool = True,
lang: str = "en-us",
) -> List[List[int]]:
"""
Args:
sentence_list:
A list of sentences.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
lang:
Language argument passed to phonemize_espeak().
Returns:
          Return a list of token-id lists, indexed as [utterance][token].
"""
token_ids_list = []
for sentence in sentence_list:
tokens_list = []
for word in sentence:
if word in self.token2id:
tokens_list.append(word)
continue
tmp_tokens_list = phonemize_espeak(word, lang)
for t in tmp_tokens_list:
tokens_list.extend(t)
token_ids = []
for t in tokens_list:
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
continue
token_ids.append(self.token2id[t])
if intersperse_blank:
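            # e.g. intersperse([3, 5], pad) -> [pad, 3, pad, 5, pad]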
token_ids = intersperse(token_ids, self.pad_id)
token_ids_list.append(token_ids)
return token_ids_list
def test_tokenizer():
import jieba
from pypinyin import lazy_pinyin, Style
tokenizer = Tokenizer("data/tokens.txt")
text1 = "今天is Monday, tomorrow is 星期二"
text2 = "你好吗? 我很好, how about you?"
text1 = list(jieba.cut(text1))
text2 = list(jieba.cut(text2))
tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True)
tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True)
print(tokens1)
print(tokens2)
ids = tokenizer.texts_to_token_ids([tokens1, tokens2])
print(ids)
if __name__ == "__main__":
test_tokenizer()

717
egs/baker_zh/TTS/matcha/train.py Executable file
View File

@@ -0,0 +1,717 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import json
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from model import fix_len_compatibility
from models.matcha_tts import MatchaTTS
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import BakerZhTtsDataModule
from utils import MetricsTracker
from icefall.checkpoint import load_checkpoint, save_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12335,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=1000,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--save-every-n",
type=int,
default=10,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
def get_data_statistics():
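    # Placeholder statistics; they are overwritten in run() with the actual
    # mean/std loaded from the --cmvn file (data/fbank/cmvn.json).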
return AttributeDict(
{
"mel_mean": 0,
"mel_std": 1,
}
)
def _get_data_params() -> AttributeDict:
params = AttributeDict(
{
"name": "baker-zh",
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
# "batch_size": 64,
# "num_workers": 1,
# "pin_memory": False,
"cleaners": ["english_cleaners2"],
"add_blank": True,
"n_spks": 1,
"n_fft": 1024,
"n_feats": 80,
"sampling_rate": 22050,
"hop_length": 256,
"win_length": 1024,
"f_min": 0,
"f_max": 8000,
"seed": 1234,
"load_durations": False,
"data_statistics": get_data_statistics(),
}
)
return params
def _get_model_params() -> AttributeDict:
n_feats = 80
filter_channels_dp = 256
encoder_params_p_dropout = 0.1
params = AttributeDict(
{
"n_spks": 1, # for baker-zh.
"spk_emb_dim": 64,
"n_feats": n_feats,
"out_size": None, # or use 172
"prior_loss": True,
"use_precomputed_durations": False,
"data_statistics": get_data_statistics(),
"encoder": AttributeDict(
{
"encoder_type": "RoPE Encoder", # not used
"encoder_params": AttributeDict(
{
"n_feats": n_feats,
"n_channels": 192,
"filter_channels": 768,
"filter_channels_dp": filter_channels_dp,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
"spk_emb_dim": 64,
"n_spks": 1,
"prenet": True,
}
),
"duration_predictor_params": AttributeDict(
{
"filter_channels_dp": filter_channels_dp,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
}
),
}
),
"decoder": AttributeDict(
{
"channels": [256, 256],
"dropout": 0.05,
"attention_head_dim": 64,
"n_blocks": 1,
"num_mid_blocks": 2,
"num_heads": 2,
"act_fn": "snakebeta",
}
),
"cfm": AttributeDict(
{
"name": "CFM",
"solver": "euler",
"sigma_min": 1e-4,
}
),
"optimizer": AttributeDict(
{
"lr": 1e-4,
"weight_decay": 0.0,
}
),
}
)
return params
def get_params():
params = AttributeDict(
{
"model_args": _get_model_params(),
"data_args": _get_data_params(),
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 10,
"valid_interval": 1500,
"env_info": get_env_info(),
}
)
return params
def get_model(params):
m = MatchaTTS(**params.model_args)
return m
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
"""Parse batch data"""
mel_mean = params.data_args.data_statistics.mel_mean
mel_std_inv = 1 / params.data_args.data_statistics.mel_std
for i in range(batch["features"].shape[0]):
n = batch["features_lens"][i]
batch["features"][i : i + 1, :n, :] = (
batch["features"][i : i + 1, :n, :] - mel_mean
) * mel_std_inv
batch["features"][i : i + 1, n:, :] = 0
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
    tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
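    # Pad the frame axis so its length satisfies fix_len_compatibility(),
    # i.e. it is a multiple of the decoder's total downsampling factor.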
max_feature_length = fix_len_compatibility(features.shape[1])
if max_feature_length > features.shape[1]:
pad = max_feature_length - features.shape[1]
features = torch.nn.functional.pad(features, (0, 0, 0, pad))
# features_lens[features_lens.argmax()] += pad
return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
    # used to summarize the stats over iterations
tot_loss = MetricsTracker()
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
batch_size = len(batch["tokens"])
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
s = 0
for key, value in losses.items():
v = value.detach().item()
loss_info[key] = v * batch_size
s += v * batch_size
loss_info["tot_loss"] = s
# summary stats
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
optimizer: Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
        The scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer=optimizer,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
# audio: (N, T), float32
# features: (N, T, C), float32
# audio_lens, (N,), int32
# features_lens, (N,), int32
# tokens: List[List[str]], len(tokens) == N
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
try:
with autocast(enabled=params.use_fp16):
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
loss = sum(losses.values())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
s = 0
for key, value in losses.items():
v = value.detach().item()
loss_info[key] = v * batch_size
s += v * batch_size
loss_info["tot_loss"] = s
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it.
# The _growth_interval of the grad scaler is configurable,
# but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, "
f"batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if params.batch_idx_train % params.valid_interval == 1:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
rank=rank,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
"Maximum memory allocated so far is "
f"{torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.pad_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of parameters: {num_param}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
logging.info("About to create datamodule")
baker_zh = BakerZhTtsDataModule(args)
train_cuts = baker_zh.train_cuts()
train_dl = baker_zh.train_dataloaders(train_cuts)
valid_cuts = baker_zh.valid_cuts()
valid_dl = baker_zh.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
if "sampler" in train_dl:
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer=optimizer,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
BakerZhTtsDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@@ -0,0 +1,340 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class BakerZhTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
    This class should be derived for specific corpora used in TTS tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
pin_memory=True,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
)

View File

@@ -82,3 +82,70 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
python3 ./local/generate_tokens.py --tokens data/tokens.txt
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Generate raw cutset"
if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then
lhotse cut simple \
-r ./data/manifests/baker_zh_recordings_all.jsonl.gz \
-s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \
./data/manifests/baker_zh_cuts_raw.jsonl.gz
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Convert text to tokens"
if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then
python3 ./local/convert_text_to_tokens.py \
--in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \
--out-file ./data/manifests/baker_zh_cuts.jsonl.gz
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate fbank (used by ./matcha)"
mkdir -p data/fbank
if [ ! -e data/fbank/.baker-zh.done ]; then
./local/compute_fbank_baker_zh.py
touch data/fbank/.baker-zh.done
fi
if [ ! -e data/fbank/.baker-zh-validated.done ]; then
log "Validating data/fbank for baker-zh (used by ./matcha)"
python3 ./local/validate_manifest.py \
data/fbank/baker_zh_cuts.jsonl.gz
touch data/fbank/.baker-zh-validated.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)"
if [ ! -e data/fbank/.baker_zh_split.done ]; then
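    # Hold out the last 600 cuts: the first 100 of them become the valid
    # set and the last 500 the test set; everything else is the train set.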
lhotse subset --last 600 \
data/fbank/baker_zh_cuts.jsonl.gz \
data/fbank/baker_zh_cuts_validtest.jsonl.gz
lhotse subset --first 100 \
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
data/fbank/baker_zh_cuts_valid.jsonl.gz
lhotse subset --last 500 \
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
data/fbank/baker_zh_cuts_test.jsonl.gz
rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
lhotse subset --first $n \
data/fbank/baker_zh_cuts.jsonl.gz \
data/fbank/baker_zh_cuts_train.jsonl.gz
touch data/fbank/.baker_zh_split.done
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 6: Compute fbank mean and std (used by ./matcha)"
if [ ! -f ./data/fbank/cmvn.json ]; then
./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
fi
fi