#!/usr/bin/env python3
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import itertools
import json
import logging
import math
import os
import random
import time
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from lhotse import Fbank, FbankConfig
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, LRScheduler
from torch.utils.tensorboard import SummaryWriter

from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints, update_averaged_model
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
)

from model import Vocos
from tts_datamodule import LibriTTSDataModule
from utils import (
    get_cosine_schedule_with_warmup,
    load_checkpoint,
    plot_spectrogram,
    save_checkpoint,
    save_checkpoint_with_global_batch_idx,
)
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-layers",
type=int,
default=8,
help="Number of ConvNeXt layers.",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=512,
help="Hidden dim of ConvNeXt module.",
)
parser.add_argument(
"--intermediate-dim",
type=int,
default=1536,
help="Intermediate dim of ConvNeXt module.",
)
parser.add_argument(
"--max-seconds",
type=int,
default=60,
help="""
The length of the precomputed normalization window sum square
(required by istft). This argument is only for onnx export, it determines
the max length of the audio that be properly normalized.
Note, you can generate audios longer than this value with the exported onnx model,
the part longer than this value will not be normalized yet.
The larger this value is the bigger the exported onnx model will be.
""",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=100,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vocos/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--learning-rate", type=float, default=0.0005, help="The learning rate."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=4000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--keep-last-epoch-k",
type=int,
default=50,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--average-period",
type=int,
default=500,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
parameters from the start of training. Each time we take the average,
we do: `model_avg = model * (average_period / batch_idx_train) +
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
parser.add_argument(
"--mrd-loss-scale",
type=float,
default=0.1,
help="The scale of MultiResolutionDiscriminator loss.",
)
parser.add_argument(
"--mel-loss-scale",
type=float,
default=45,
help="The scale of melspectrogram loss.",
)
add_model_arguments(parser)
return parser
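

# A minimal numeric sketch (illustration only, not used in training) of the
# running average described in the `--average-period` help above. With
# average_period=500 and batch_idx_train=2000, each parameter of `model_avg`
# becomes model * (500/2000) + model_avg * (1500/2000), i.e. an incremental
# mean over the models sampled every `average_period` batches.
def _averaged_param_sketch(
    cur: float, avg: float, average_period: int, batch_idx_train: int
) -> float:
    weight = average_period / batch_idx_train
    return cur * weight + avg * (1 - weight)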
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
        - batch_idx_train: Used for writing statistics to tensorboard. It
                           contains the number of batches trained so far
                           across epochs.
        - log_interval: Print training loss if batch_idx % log_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 500,
"feature_dim": 80,
"segment_size": 16384,
"adam_b1": 0.9,
"adam_b2": 0.99,
"warmup_steps": 0,
"max_steps": 2000000,
"env_info": get_env_info(),
}
)
return params
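

# Usage sketch (illustration only, never called during training): `params`
# is an AttributeDict, so values can be read either dict-style or
# attribute-style; command-line options merged in via
# `params.update(vars(args))` (see `run()`) are accessed the same way.
def _params_access_sketch() -> None:
    params = get_params()
    assert params["feature_dim"] == params.feature_dim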
def get_model(params: AttributeDict) -> nn.Module:
device = params.device
model = Vocos(
feature_dim=params.feature_dim,
dim=params.hidden_dim,
n_fft=params.frame_length,
hop_length=params.frame_shift,
intermediate_dim=params.intermediate_dim,
num_layers=params.num_layers,
sample_rate=params.sampling_rate,
max_seconds=params.max_seconds,
).to(device)
    num_param_gen = sum(p.numel() for p in model.generator.parameters())
    logging.info(f"Number of Generator parameters : {num_param_gen}")
    num_param_mpd = sum(p.numel() for p in model.mpd.parameters())
    logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}")
    num_param_mrd = sum(p.numel() for p in model.mrd.parameters())
    logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}")
logging.info(
f"Number of model parameters : {num_param_gen + num_param_mpd + num_param_mrd}"
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
model_avg: nn.Module = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
    Apart from loading the state dicts for `model` and the optimizers, it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
model_avg:
The stored model averaged from the start of training.
      optimizer_g:
        The optimizer for the generator.
      optimizer_d:
        The optimizer for the discriminators.
      scheduler_g:
        The learning rate scheduler for the generator.
      scheduler_d:
        The learning rate scheduler for the discriminators.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(
filename,
model=model,
model_avg=model_avg,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
return saved_params
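

# A standalone sketch (illustration only) of the resume-priority logic
# implemented above: a positive --start-batch takes precedence over
# --start-epoch, and --start-epoch equal to 1 means training from scratch.
def _resume_filename_sketch(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # no checkpoint to load; train from scratch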
def compute_generator_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
features: Tensor,
audios: Tensor,
) -> Tuple[Tensor, MetricsTracker]:
device = params.device
model = model.module if isinstance(model, DDP) else model
audios_hat = model(features) # (B, T)
mel_loss = model.melspec_loss(audios_hat, audios)
_, gen_score_mpd, fmap_rs_mpd, fmap_gs_mpd = model.mpd(y=audios, y_hat=audios_hat)
_, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = model.mrd(y=audios, y_hat=audios_hat)
loss_gen_mpd, list_loss_gen_mpd = model.gen_loss(disc_outputs=gen_score_mpd)
loss_gen_mrd, list_loss_gen_mrd = model.gen_loss(disc_outputs=gen_score_mrd)
loss_gen_mpd = loss_gen_mpd / len(list_loss_gen_mpd)
loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
loss_fm_mpd = model.feat_matching_loss(
fmap_r=fmap_rs_mpd, fmap_g=fmap_gs_mpd
) / len(fmap_rs_mpd)
loss_fm_mrd = model.feat_matching_loss(
fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd
) / len(fmap_rs_mrd)
loss_gen_all = (
loss_gen_mpd
+ params.mrd_loss_scale * loss_gen_mrd
+ loss_fm_mpd
+ params.mrd_loss_scale * loss_fm_mrd
+ params.mel_loss_scale * mel_loss
)
    assert loss_gen_all.requires_grad
info = MetricsTracker()
info["frames"] = 1
info["loss_gen"] = loss_gen_all.detach().cpu().item()
info["loss_mel"] = mel_loss.detach().cpu().item()
info["loss_feature_mpd"] = loss_fm_mpd.detach().cpu().item()
info["loss_feature_mrd"] = loss_fm_mrd.detach().cpu().item()
info["loss_gen_mrd"] = loss_gen_mrd.detach().cpu().item()
info["loss_gen_mpd"] = loss_gen_mpd.detach().cpu().item()
return loss_gen_all, info
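

# Hedged sketches (illustration only) of what `model.gen_loss` and
# `model.feat_matching_loss` are assumed to compute, following the
# HiFi-GAN-style least-squares GAN objective: the generator pushes
# discriminator scores on generated audio towards 1, and feature matching is
# an L1 distance between real and generated discriminator feature maps. The
# actual implementations live in `model.py`.
def _gen_loss_sketch(disc_outputs: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
    losses = [torch.mean((1 - dg) ** 2) for dg in disc_outputs]
    return sum(losses), losses


def _feat_matching_loss_sketch(
    fmap_r: List[List[Tensor]], fmap_g: List[List[Tensor]]
) -> Tensor:
    loss = torch.tensor(0.0)
    for dr, dg in zip(fmap_r, fmap_g):
        for r, g in zip(dr, dg):
            loss = loss + torch.mean(torch.abs(r - g))
    return loss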
def compute_discriminator_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
features: Tensor,
audios: Tensor,
) -> Tuple[Tensor, MetricsTracker]:
device = params.device
model = model.module if isinstance(model, DDP) else model
with torch.no_grad():
        audios_hat = model(features)  # (B, T)
real_score_mpd, gen_score_mpd, _, _ = model.mpd(y=audios, y_hat=audios_hat)
real_score_mrd, gen_score_mrd, _, _ = model.mrd(y=audios, y_hat=audios_hat)
loss_mpd, loss_mpd_real, loss_mpd_gen = model.disc_loss(
disc_real_outputs=real_score_mpd, disc_generated_outputs=gen_score_mpd
)
loss_mrd, loss_mrd_real, loss_mrd_gen = model.disc_loss(
disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
)
loss_mpd /= len(loss_mpd_real)
loss_mrd /= len(loss_mrd_real)
loss_disc_all = loss_mpd + params.mrd_loss_scale * loss_mrd
info = MetricsTracker()
    # MetricsTracker normalizes loss values by "frames"; set it to 1 here
    # so that tot_loss reflects the per-batch loss.
info["frames"] = 1
info["loss_disc"] = loss_disc_all.detach().cpu().item()
info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()
return loss_disc_all, info
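

# A hedged sketch (illustration only) of what `model.disc_loss` is assumed
# to compute: the least-squares GAN discriminator objective used by
# HiFi-GAN-style vocoders, pushing real scores towards 1 and generated
# scores towards 0, summed over sub-discriminators. The actual
# implementation lives in `model.py`.
def _disc_loss_sketch(
    disc_real_outputs: List[Tensor], disc_generated_outputs: List[Tensor]
) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
    real_losses, gen_losses = [], []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        real_losses.append(torch.mean((1 - dr) ** 2))  # real score -> 1
        gen_losses.append(torch.mean(dg**2))  # generated score -> 0
    loss = sum(real_losses) + sum(gen_losses)
    return loss, real_losses, gen_losses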
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer_g: Optimizer,
optimizer_d: Optimizer,
    scheduler_g: LRScheduler,
    scheduler_d: LRScheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
      optimizer_g:
        The optimizer for the generator.
      optimizer_d:
        The optimizer for the discriminators.
      scheduler_g:
        The learning rate scheduler for the generator; step() is called
        after every batch.
      scheduler_d:
        The learning rate scheduler for the discriminators; step() is called
        after every batch.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
        The scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = batch["features_lens"].size(0)
features = batch["features"].to(device) # (B, T, F)
features_lens = batch["features_lens"].to(device)
audios = batch["audio"].to(device)
segment_frames = (
params.segment_size - params.frame_length
) // params.frame_shift + 1
start_p = random.randint(0, features_lens.min() - (segment_frames + 1))
features = features[:, start_p : start_p + segment_frames, :].permute(
0, 2, 1
) # (B, F, T)
audios = audios[
:,
start_p * params.frame_shift : start_p * params.frame_shift
+ params.segment_size,
] # (B, T)
try:
optimizer_d.zero_grad()
loss_disc, loss_disc_info = compute_discriminator_loss(
params=params,
model=model,
features=features,
audios=audios,
)
loss_disc.backward()
optimizer_d.step()
scheduler_d.step()
optimizer_g.zero_grad()
loss_gen, loss_gen_info = compute_generator_loss(
params=params,
model=model,
features=features,
audios=audios,
)
loss_gen.backward()
optimizer_g.step()
scheduler_g.step()
loss_info = loss_gen_info + loss_disc_info
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
except Exception as e:
logging.info(f"Caught exception : {e}.")
save_bad_model()
raise
if params.print_diagnostics and batch_idx == 5:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_lr_g = max(scheduler_g.get_last_lr())
cur_lr_d = max(scheduler_d.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.4e}, "
f"cur_lr_d: {cur_lr_d:.4e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate_gen", cur_lr_g, params.batch_idx_train
)
tb_writer.add_scalar(
"train/learning_rate_disc", cur_lr_d, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if (
params.batch_idx_train % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
rank=rank,
tb_writer=tb_writer,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss_gen"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
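

# Worked example (illustration only) of the segment-cropping arithmetic used
# in train_one_epoch. The STFT settings below are assumptions made for the
# sake of the example; the real values come from `params.frame_length` and
# `params.frame_shift`, which are supplied by the data module arguments.
# With segment_size=16384 samples, frame_length=1024 and frame_shift=256:
#   (16384 - 1024) // 256 + 1 = 61 feature frames per training segment,
# and a crop starting at feature frame `start_p` corresponds to the audio
# samples [start_p * 256, start_p * 256 + 16384).
def _segment_frames_sketch(
    segment_size: int, frame_length: int, frame_shift: int
) -> int:
    return (segment_size - frame_length) // frame_shift + 1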
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
tb_writer: Optional[SummaryWriter] = None,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
torch.cuda.empty_cache()
model = model.module if isinstance(model, DDP) else model
device = next(model.parameters()).device
# used to summary the stats over iterations
tot_loss = MetricsTracker()
with torch.no_grad():
infer_time = 0
audio_time = 0
for batch_idx, batch in enumerate(valid_dl):
features = batch["features"] # (B, T, F)
features_lens = batch["features_lens"]
audio_time += torch.sum(features_lens)
x = features.permute(0, 2, 1) # (B, F, T)
y = batch["audio"].to(device) # (B, T)
start = time.time()
y_g_hat = model(x.to(device)) # (B, T)
infer_time += time.time() - start
if y_g_hat.size(1) > y.size(1):
y = torch.cat(
[
y,
torch.zeros(
(y.size(0), y_g_hat.size(1) - y.size(1)), device=device
),
],
dim=1,
)
else:
y = y[:, 0 : y_g_hat.size(1)]
loss_mel_error = model.melspec_loss(y_g_hat, y)
loss_info = MetricsTracker()
            # MetricsTracker normalizes loss values by "frames"; set it to 1
            # here so that tot_loss reflects the per-batch loss.
loss_info["frames"] = 1
loss_info["loss_mel_error"] = loss_mel_error.item()
tot_loss = tot_loss + loss_info
if batch_idx <= 5 and rank == 0 and tb_writer is not None:
if params.batch_idx_train == params.valid_interval:
tb_writer.add_audio(
"gt/y_{}".format(batch_idx),
y[0],
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"generated/y_hat_{}".format(batch_idx),
y_g_hat[0],
params.batch_idx_train,
params.sampling_rate,
)
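        # Note: the RTF below assumes a 10 ms frame shift, so `audio_time`
        # feature frames correspond to audio_time * 10 / 1000 seconds of
        # audio.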
logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["loss_mel_error"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
torch.autograd.set_detect_anomaly(True)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_model(params)
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
model = model.to(device)
generator = model.generator
mrd = model.mrd
mpd = model.mpd
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW(
generator.parameters(),
params.learning_rate,
betas=[params.adam_b1, params.adam_b2],
)
optimizer_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()),
params.learning_rate,
betas=[params.adam_b1, params.adam_b2],
)
scheduler_g = get_cosine_schedule_with_warmup(
optimizer_g,
num_warmup_steps=params.warmup_steps,
num_training_steps=params.max_steps,
)
scheduler_d = get_cosine_schedule_with_warmup(
optimizer_d,
num_warmup_steps=params.warmup_steps,
num_training_steps=params.max_steps,
)
if checkpoints is not None:
# load state_dict for optimizers
if "optimizer_g" in checkpoints:
logging.info("Loading generator optimizer state dict")
optimizer_g.load_state_dict(checkpoints["optimizer_g"])
if "optimizer_d" in checkpoints:
logging.info("Loading discriminator optimizer state dict")
optimizer_d.load_state_dict(checkpoints["optimizer_d"])
# load state_dict for schedulers
if "scheduler_g" in checkpoints:
logging.info("Loading generator scheduler state dict")
scheduler_g.load_state_dict(checkpoints["scheduler_g"])
if "scheduler_d" in checkpoints:
logging.info("Loading discriminator scheduler state dict")
scheduler_d.load_state_dict(checkpoints["scheduler_d"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
libritts = LibriTTSDataModule(args)
train_cuts = libritts.train_clean_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = libritts.train_dataloaders(train_cuts)
valid_cuts = libritts.dev_clean_cuts()
valid_dl = libritts.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
model_avg=model_avg,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_epoch_k,
prefix="epoch",
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
LibriTTSDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()