2024-10-29 15:04:04 +08:00

724 lines
22 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import json
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from matcha.model import fix_len_compatibility
from matcha.models.matcha_tts import MatchaTTS
from matcha.tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker
from icefall.checkpoint import load_checkpoint, save_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
    """Build the command-line parser for the training script.

    Only the options defined in this file are registered here; ``main``
    additionally lets ``LJSpeechTtsDataModule`` register its own arguments
    on the returned parser.

    Returns:
      An ``argparse.ArgumentParser`` whose help output shows default values.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12335,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1000,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    # The file is parsed with json.load() in run(), which reads the keys
    # "fbank_mean" and "fbank_std" from it.
    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="""Path to the JSON file with the feature normalization
        statistics. It must contain the keys fbank_mean and fbank_std.
        """,
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--save-every-n",
        type=int,
        default=10,
        help="""Save checkpoint after processing this number of epochs
        periodically. We save checkpoint to exp-dir/ whenever
        params.cur_epoch % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
        Since it will take around 1000 epochs, we suggest using a large
        save_every_n to save disk space.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    return parser
def get_data_statistics():
    """Return placeholder mel statistics (mean 0, standard deviation 1).

    ``run`` replaces both values with the numbers read from the file given
    by ``--cmvn`` before the model is created.
    """
    return AttributeDict(mel_mean=0, mel_std=1)
def _get_data_params() -> AttributeDict:
    """Return the fixed data-related settings for the LJSpeech recipe.

    The ``data_statistics`` entry starts out with the placeholder values
    from :func:`get_data_statistics`.
    """
    settings = {
        "name": "ljspeech",
        "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
        "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
        # "batch_size": 64,
        # "num_workers": 1,
        # "pin_memory": False,
        "cleaners": ["english_cleaners2"],
        "add_blank": True,
        "n_spks": 1,
        "n_fft": 1024,
        "n_feats": 80,
        "sample_rate": 22050,
        "hop_length": 256,
        "win_length": 1024,
        "f_min": 0,
        "f_max": 8000,
        "seed": 1234,
        "load_durations": False,
        "data_statistics": get_data_statistics(),
    }
    return AttributeDict(settings)
def _get_model_params() -> AttributeDict:
    """Return the keyword arguments used to construct ``MatchaTTS``.

    The nested encoder / decoder / cfm / optimizer sections are assembled
    first and then combined into a single ``AttributeDict``.
    """
    feat_dim = 80
    dp_filter_channels = 256
    dropout_p = 0.1

    text_encoder = AttributeDict(
        {
            "n_feats": feat_dim,
            "n_channels": 192,
            "filter_channels": 768,
            "filter_channels_dp": dp_filter_channels,
            "n_heads": 2,
            "n_layers": 6,
            "kernel_size": 3,
            "p_dropout": dropout_p,
            "spk_emb_dim": 64,
            "n_spks": 1,
            "prenet": True,
        }
    )

    # Shares the filter width and dropout probability with the text encoder.
    duration_predictor = AttributeDict(
        {
            "filter_channels_dp": dp_filter_channels,
            "kernel_size": 3,
            "p_dropout": dropout_p,
        }
    )

    encoder = AttributeDict(
        {
            "encoder_type": "RoPE Encoder",  # not used
            "encoder_params": text_encoder,
            "duration_predictor_params": duration_predictor,
        }
    )

    decoder = AttributeDict(
        {
            "channels": [256, 256],
            "dropout": 0.05,
            "attention_head_dim": 64,
            "n_blocks": 1,
            "num_mid_blocks": 2,
            "num_heads": 2,
            "act_fn": "snakebeta",
        }
    )

    cfm = AttributeDict(
        {
            "name": "CFM",
            "solver": "euler",
            "sigma_min": 1e-4,
        }
    )

    optimizer = AttributeDict(
        {
            "lr": 1e-4,
            "weight_decay": 0.0,
        }
    )

    return AttributeDict(
        {
            "n_spks": 1,  # for ljspeech.
            "spk_emb_dim": 64,
            "n_feats": feat_dim,
            "out_size": None,  # or use 172
            "prior_loss": True,
            "use_precomputed_durations": False,
            "data_statistics": get_data_statistics(),
            "encoder": encoder,
            "decoder": decoder,
            "cfm": cfm,
            "optimizer": optimizer,
        }
    )
def get_params():
    """Return the default training state and configuration.

    Besides the model and data settings this holds the bookkeeping fields
    (best losses/epochs, global batch counter) and the logging/validation
    intervals.
    """
    return AttributeDict(
        model_args=_get_model_params(),
        data_args=_get_data_params(),
        best_train_loss=float("inf"),
        best_valid_loss=float("inf"),
        best_train_epoch=-1,
        best_valid_epoch=-1,
        # train_one_epoch increments this before using it, so the first
        # processed batch gets the global index 0.
        batch_idx_train=-1,
        log_interval=10,
        valid_interval=1500,
        env_info=get_env_info(),
    )
def get_model(params):
    """Create a ``MatchaTTS`` instance from ``params.model_args``."""
    return MatchaTTS(**params.model_args)
def load_checkpoint_if_available(
    params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
    """Restore the model from the checkpoint preceding ``params.start_epoch``.

    Nothing is loaded when ``params.start_epoch`` is 1 or less. Otherwise the
    file ``exp_dir/epoch-{start_epoch-1}.pt`` is loaded into `model` and the
    fields `best_train_epoch`, `best_valid_epoch`, `batch_idx_train`,
    `best_train_loss` and `best_valid_loss` of `params` are overwritten with
    the saved values.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      The dict returned by ``load_checkpoint``, or None if no checkpoint
      was loaded.
    """
    if params.start_epoch <= 1:
        return None

    ckpt_path = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    assert ckpt_path.is_file(), f"{ckpt_path} does not exist!"

    restored = load_checkpoint(ckpt_path, model=model)

    # Carry the training bookkeeping over from the saved run.
    for field in (
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ):
        params[field] = restored[field]

    return restored
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
    """Normalize a batch and convert it to the tensors the model consumes.

    ``batch["features"]`` is normalized in place with the mel statistics from
    ``params.data_args.data_statistics``; frames past each utterance's length
    are set to zero. Tokens are mapped to ids, padded with
    ``tokenizer.pad_id`` and moved to `device`. The time axis of the features
    is padded with zeros up to ``fix_len_compatibility(features.shape[1])``.

    Returns:
      A tuple ``(audio, audio_lens, features, features_lens, tokens,
      tokens_lens)``; the two ``*_lens`` entries for features and tokens are
      converted with ``.long()``.
    """
    stats = params.data_args.data_statistics
    mean = stats.mel_mean
    inv_std = 1 / stats.mel_std

    feats = batch["features"]
    for idx in range(feats.shape[0]):
        num_frames = batch["features_lens"][idx]
        valid = feats[idx : idx + 1, :num_frames, :]
        feats[idx : idx + 1, :num_frames, :] = (valid - mean) * inv_std
        # Padding frames must stay at zero after normalization.
        feats[idx : idx + 1, num_frames:, :] = 0

    audio = batch["audio"].to(device)
    features = batch["features"].to(device)
    audio_lens = batch["audio_lens"].to(device)
    features_lens = batch["features_lens"].to(device)

    token_ids = tokenizer.tokens_to_token_ids(
        batch["tokens"], intersperse_blank=True, add_sos=True, add_eos=True
    )
    ragged = k2.RaggedTensor(token_ids)
    # Length of each row = difference of consecutive row splits.
    splits = ragged.shape.row_splits(1)
    tokens_lens = (splits[1:] - splits[:-1]).to(device)

    # a tensor of shape (B, T)
    tokens = ragged.to(device).pad(mode="constant", padding_value=tokenizer.pad_id)

    cur_len = features.shape[1]
    target_len = fix_len_compatibility(cur_len)
    if target_len > cur_len:
        extra = target_len - cur_len
        # Pad only at the end of the time axis (second-to-last dimension).
        features = torch.nn.functional.pad(features, (0, 0, 0, extra))

    return (
        audio,
        audio_lens,
        features,
        features_lens.long(),
        tokens,
        tokens_lens.long(),
    )
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> MetricsTracker:
    """Evaluate the model on the validation set.

    The model is switched to eval mode and is left in that mode. Each loss
    is weighted by the batch size and accumulated; with ``world_size > 1``
    the accumulated statistics are reduced via ``MetricsTracker.reduce``.
    ``params.best_valid_loss`` and ``params.best_valid_epoch`` are updated
    when the averaged total loss improves.

    Returns:
      The accumulated ``MetricsTracker``.
    """
    model.eval()

    if isinstance(model, DDP):
        device = model.device
        get_losses = model.module.get_losses
    else:
        device = next(model.parameters()).device
        get_losses = model.get_losses

    # used to summary the stats over iterations
    tot_loss = MetricsTracker()

    with torch.no_grad():
        for batch in valid_dl:
            (
                _audio,
                _audio_lens,
                features,
                features_lens,
                tokens,
                tokens_lens,
            ) = prepare_input(batch, tokenizer, device, params)

            model_inputs = {
                "x": tokens,
                "x_lengths": tokens_lens,
                "y": features.permute(0, 2, 1),
                "y_lengths": features_lens,
                "spks": None,  # should change it for multi-speakers
                "durations": None,
            }
            losses = get_losses(model_inputs)

            num_samples = len(batch["tokens"])

            batch_stats = MetricsTracker()
            batch_stats["samples"] = num_samples

            weighted_total = 0
            for name, loss_tensor in losses.items():
                weighted = loss_tensor.detach().item() * num_samples
                batch_stats[name] = weighted
                weighted_total += weighted
            batch_stats["tot_loss"] = weighted_total

            # summary stats
            tot_loss = tot_loss + batch_stats

    if world_size > 1:
        tot_loss.reduce(device)

    avg_loss = tot_loss["tot_loss"] / tot_loss["samples"]
    if avg_loss < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = avg_loss

    return tot_loss
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    optimizer: Optimizer,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The sample-weighted average of the total loss over the epoch is saved in
    `params.train_loss`. The validation process runs whenever
    `params.batch_idx_train % params.valid_interval == 1`.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      tokenizer:
        Used by :func:`prepare_input` to convert tokens to token ids.
      optimizer:
        The optimizer.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of training processes; forwarded to the validation routine.
      rank:
        Rank of this process; used in the file name of "bad model"
        checkpoints and forwarded to the validation routine.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses

    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        # rank=0 is passed on purpose so that every process writes its own
        # file; the real rank is part of the file name.
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer=optimizer,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        # audio: (N, T), float32
        # features: (N, T, C), float32
        # audio_lens, (N,), int32
        # features_lens, (N,), int32
        # tokens: List[List[str]], len(tokens) == N

        batch_size = len(batch["tokens"])

        (
            audio,
            audio_lens,
            features,
            features_lens,
            tokens,
            tokens_lens,
        ) = prepare_input(batch, tokenizer, device, params)

        try:
            # Only the forward pass and the loss computation run under
            # autocast.
            with autocast(enabled=params.use_fp16):
                losses = get_losses(
                    {
                        "x": tokens,
                        "x_lengths": tokens_lens,
                        "y": features.permute(0, 2, 1),
                        "y_lengths": features_lens,
                        "spks": None,  # should change it for multi-speakers
                        "durations": None,
                    }
                )

                loss = sum(losses.values())

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            # update() must follow every step(): without it an enabled
            # GradScaler refuses the next step() and never adapts its scale.
            # It is a no-op when the scaler is disabled.
            scaler.update()

            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            s = 0

            for key, value in losses.items():
                v = value.detach().item()
                loss_info[key] = v * batch_size
                s += v * batch_size

            loss_info["tot_loss"] = s

            tot_loss = tot_loss + loss_info
        except:  # noqa
            # Keep a snapshot for debugging, then let the error propagate.
            save_bad_model()
            raise

        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it.
            # The _growth_interval of the grad scaler is configurable,
            # but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if params.batch_idx_train % params.log_interval == 0:
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, "
                f"batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                + (f"grad_scale: {cur_grad_scale}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

        if params.batch_idx_train % params.valid_interval == 1:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                tokenizer=tokenizer,
                valid_dl=valid_dl,
                world_size=world_size,
                rank=rank,
            )
            # The validation routine leaves the model in eval mode.
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                "Maximum memory allocated so far is "
                f"{torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
def run(rank, world_size, args):
    """Run the whole training procedure in one process.

    Args:
      rank:
        Index of this process. It selects the CUDA device when CUDA is
        available and decides which process writes tensorboard logs and the
        "best" checkpoint copies (rank 0).
      world_size:
        Number of training processes. DDP is used if it is larger than 1.
      args:
        The parsed command-line arguments; they are merged into the
        parameters returned by :func:`get_params`.
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.pad_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    # Replace the placeholder statistics with the ones computed from data.
    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]

    logging.info(params)
    print(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of parameters: {num_param}")

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)

    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)

    # Checkpoints are written with the optimizer state; restore it so that a
    # resumed run does not restart Adam with empty moment estimates.
    if checkpoints and checkpoints.get("optimizer"):
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    logging.info("About to create datamodule")

    ljspeech = LJSpeechTtsDataModule(args)

    train_cuts = ljspeech.train_cuts()
    train_dl = ljspeech.train_dataloaders(train_cuts)

    valid_cuts = ljspeech.valid_cuts()
    valid_dl = ljspeech.valid_dataloaders(valid_cuts)

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    # A scaler that was disabled when the checkpoint was written saves an
    # empty state, which an enabled scaler refuses to load; skip it then.
    if checkpoints and checkpoints.get("grad_scaler"):
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")
        fix_random_seed(params.seed + epoch - 1)

        # Look the sampler up as an attribute. A membership test on the
        # dataloader would iterate over every batch and never be true.
        sampler = getattr(train_dl, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            tokenizer=tokenizer,
            optimizer=optimizer,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer=optimizer,
                scaler=scaler,
                rank=rank,
            )

            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

    if tb_writer is not None:
        # Flush pending events and release the event file.
        tb_writer.close()

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
def main():
    """Parse the command line and start training in one or more processes."""
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()

    num_procs = args.world_size
    assert num_procs >= 1

    if num_procs <= 1:
        # Single process: train directly in the current process.
        run(rank=0, world_size=1, args=args)
        return

    # One process per rank; mp.spawn passes the rank as the first argument.
    mp.spawn(run, args=(num_procs, args), nprocs=num_procs, join=True)
if __name__ == "__main__":
    # Limit PyTorch to one intra-op and one inter-op thread (in that order)
    # before training starts.
    for _set_thread_limit in (torch.set_num_threads, torch.set_num_interop_threads):
        _set_thread_limit(1)
    main()