First working version.

This commit is contained in:
Fangjun Kuang 2024-10-16 19:35:35 +08:00
parent ccd2dcc9f9
commit 56d3b92f3f
5 changed files with 853 additions and 41 deletions

View File

@ -0,0 +1,178 @@
#!/usr/bin/env python3
import argparse
import datetime as dt
import logging
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse
from tqdm.auto import tqdm
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict
def get_parser():
    """Create the command-line parser for this script.

    Returns:
      An ``argparse.ArgumentParser`` that understands ``--epoch`` and
      ``--exp-dir``. Defaults are included in the ``--help`` output.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``).
    option_specs = [
        (
            "--epoch",
            {
                "type": int,
                "default": 140,
                "help": """It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
            },
        ),
        (
            "--exp-dir",
            {
                "type": Path,
                "default": "matcha/exp",
                "help": """The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
def load_vocoder(checkpoint_path):
    """Load a HiFi-GAN generator on the CPU and prepare it for inference.

    Args:
      checkpoint_path:
        Path to a checkpoint readable by ``torch.load`` whose ``"generator"``
        entry holds the generator state dict.

    Returns:
      The generator in eval mode with weight norm removed.
    """
    config = AttributeDict(v1)
    generator = HiFiGAN(config).to("cpu")

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    generator.load_state_dict(checkpoint["generator"])

    generator.eval()
    generator.remove_weight_norm()
    return generator
def to_waveform(mel, vocoder, denoiser, strength=0.00025):
    """Convert a mel-spectrogram to a denoised waveform on the CPU.

    Args:
      mel:
        Mel-spectrogram passed unchanged to ``vocoder``.
      vocoder:
        Callable mapping ``mel`` to an audio tensor.
      denoiser:
        Callable invoked as ``denoiser(audio, strength=...)``.
      strength:
        Denoising strength forwarded to ``denoiser``. The default keeps the
        value that used to be hard-coded.

    Returns:
      The denoised audio tensor, moved to the CPU with all size-1
      dimensions squeezed out.
    """
    # Clamp to the valid range for audio samples before denoising.
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=strength)
    # A single ``.cpu().squeeze()`` is enough; applying it twice is a no-op.
    return audio.cpu().squeeze()
def save_to_folder(filename: str, output: dict, folder: str):
    """Write the mel-spectrogram and waveform from ``output`` into ``folder``.

    Args:
      filename:
        Base name (without extension) of the files to create.
      output:
        Dict with a ``"mel"`` tensor and a ``"waveform"`` array/tensor.
      folder:
        Destination directory; created (with parents) if missing.
    """
    out_dir = Path(folder)
    out_dir.mkdir(exist_ok=True, parents=True)

    mel_path = out_dir / f"{filename}"
    wav_path = out_dir / f"{filename}.wav"

    mel = output["mel"].cpu().numpy()
    # ``np.save`` appends ``.npy`` when the name has no such extension.
    np.save(mel_path, mel)
    # 22050 is the sample rate in Hz; samples are stored as 24-bit PCM.
    sf.write(wav_path, output["waveform"], 22050, "PCM_24")
def process_text(text: str):
    """Turn raw text into the tensors consumed by the model.

    Args:
      text:
        The sentence to synthesise.

    Returns:
      A dict with:
        - ``x_orig``: the input text, unchanged.
        - ``x``: int64 CPU tensor of token ids with a leading batch axis.
        - ``x_lengths``: int64 CPU tensor holding the length of ``x``.
        - ``x_phones``: ``sequence_to_text`` applied to the token ids.
    """
    token_ids = text_to_sequence(text, ["english_cleaners2"])[0]
    # presumably inserts the id 0 between consecutive symbols - TODO confirm
    padded_ids = intersperse(token_ids, 0)

    x = torch.tensor(padded_ids, dtype=torch.long, device="cpu")[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
    x_phones = sequence_to_text(x.squeeze(0).tolist())

    return dict(x_orig=text, x=x, x_lengths=x_lengths, x_phones=x_phones)
def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
    """Run ``model.synthesise`` on ``text`` and attach bookkeeping info.

    Args:
      model:
        Object exposing ``synthesise(x, x_lengths, n_timesteps=...,
        temperature=..., spks=..., length_scale=...)`` that returns a dict
        containing at least ``"mel"``.
      n_timesteps:
        Forwarded to ``model.synthesise``.
      text:
        The sentence to synthesise.
      length_scale:
        Forwarded to ``model.synthesise``.
      temperature:
        Forwarded to ``model.synthesise``.
      spks:
        Optional speaker argument forwarded to ``model.synthesise``.

    Returns:
      The model's output dict, extended with ``"start_t"`` (the time at
      which synthesis started) and every entry from :func:`process_text`.
    """
    inputs = process_text(text)
    # The clock starts after text processing, so text processing time is
    # not part of the measurement.
    started_at = dt.datetime.now()

    result = model.synthesise(
        inputs["x"],
        inputs["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    print("output.shape", list(result.keys()), result["mel"].shape)

    # Fold everything into a single dict for the caller.
    result["start_t"] = started_at
    result.update(inputs)
    return result
@torch.inference_mode()
def main():
    """Synthesise a fixed set of sentences and report real-time factors.

    Loads ``{exp_dir}/epoch-{epoch}.pt`` into the model, runs synthesis and
    the HiFi-GAN vocoder on each sentence, prints the input text, the
    phonetised text and the real-time factors, and writes the mel and the
    waveform of each sentence to ``./my-output``.
    """
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()
    params.update(vars(args))
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.eval()

    # NOTE(review): this absolute path only exists on the original author's
    # machine; it should become a command-line option.
    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
    denoiser = Denoiser(vocoder, mode="zeros")

    texts = [
        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
    ]

    # Number of ODE Solver steps
    n_timesteps = 2

    # Changes to the speaking rate
    length_scale = 1.0

    # Sampling temperature
    temperature = 0.667

    # Sample rate in Hz used to turn a sample count into a duration.
    sample_rate = 22050

    rtfs, rtfs_w = [], []
    for i, text in enumerate(tqdm(texts)):
        output = synthesise(
            model=model,
            n_timesteps=n_timesteps,
            text=text,
            length_scale=length_scale,
            temperature=temperature,
        )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        # Compute Real Time Factor (RTF) with HiFi-GAN: elapsed wall-clock
        # time divided by the duration of the generated audio.
        elapsed = (dt.datetime.now() - output["start_t"]).total_seconds()
        rtf_w = elapsed * sample_rate / (output["waveform"].shape[-1])

        # Pretty print
        print(f"{'*' * 53}")
        print(f"Input text - {i}")
        print(f"{'-' * 53}")
        print(output["x_orig"])
        print(f"{'*' * 53}")
        print(f"Phonetised text - {i}")
        print(f"{'-' * 53}")
        print(output["x_phones"])
        print(f"{'*' * 53}")
        # ``rtf`` is reported by ``model.synthesise`` itself.
        print(f"RTF:\t\t{output['rtf']:.6f}")
        print(f"RTF Waveform:\t{rtf_w:.6f}")
        rtfs.append(output["rtf"])
        rtfs_w.append(rtf_w)

        # Save the generated waveform; the sentence index is the file name.
        save_to_folder(str(i), output, folder="./my-output")

    print(f"Number of ODE steps: {n_timesteps}")
    print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
    print(
        f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}"
    )
if __name__ == "__main__":
    # Configure root logging first so that messages emitted by ``main`` carry
    # a timestamp, the level and the source location.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()

View File

@ -5,6 +5,7 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
# from matcha import utils
# from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
@ -30,7 +31,7 @@ class MatchaTTS(torch.nn.Module): # 🍵
encoder,
decoder,
cfm,
# data_statistics,
data_statistics,
out_size,
optimizer=None,
scheduler=None,
@ -71,9 +72,13 @@ class MatchaTTS(torch.nn.Module): # 🍵
)
# self.update_data_statistics(data_statistics)
self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
def synthesise(
self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0
):
"""
Generates mel-spectrogram from text. Returns:
1. encoder outputs
@ -149,7 +154,17 @@ class MatchaTTS(torch.nn.Module): # 🍵
"rtf": rtf,
}
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
def forward(
self,
x,
x_lengths,
y,
y_lengths,
spks=None,
out_size=None,
cond=None,
durations=None,
):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
@ -187,7 +202,9 @@ class MatchaTTS(torch.nn.Module): # 🍵
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
factor = -0.5 * torch.ones(
mu_x.shape, dtype=mu_x.dtype, device=mu_x.device
)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
@ -206,12 +223,25 @@ class MatchaTTS(torch.nn.Module): # 🍵
# - Do not need this hack for Matcha-TTS, but it works with it as well
if not isinstance(out_size, type(None)):
max_offset = (y_lengths - out_size).clamp(0)
offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
offset_ranges = list(
zip([0] * max_offset.shape[0], max_offset.cpu().numpy())
)
out_offset = torch.LongTensor(
[torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
[
torch.tensor(random.choice(range(start, end)) if end > start else 0)
for start, end in offset_ranges
]
).to(y_lengths)
attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
attn_cut = torch.zeros(
attn.shape[0],
attn.shape[1],
out_size,
dtype=attn.dtype,
device=attn.device,
)
y_cut = torch.zeros(
y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device
)
y_cut_lengths = []
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
@ -233,12 +263,36 @@ class MatchaTTS(torch.nn.Module): # 🍵
mu_y = mu_y.transpose(1, 2)
# Compute loss of the decoder
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
diff_loss, _ = self.decoder.compute_loss(
x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond
)
if self.prior_loss:
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = torch.sum(
0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss, attn
def get_losses(self, batch):
    """Run the forward pass on ``batch`` and return the losses by name.

    Args:
      batch:
        Dict providing the keys ``x``, ``x_lengths``, ``y``, ``y_lengths``,
        ``spks`` and ``durations``.

    Returns:
      A dict with the keys ``dur_loss``, ``prior_loss`` and ``diff_loss``.
      Anything else returned by the forward pass is discarded.
    """
    dur_loss, prior_loss, diff_loss, *_ = self(
        x=batch["x"],
        x_lengths=batch["x_lengths"],
        y=batch["y"],
        y_lengths=batch["y_lengths"],
        spks=batch["spks"],
        out_size=self.out_size,
        durations=batch["durations"],
    )
    return dict(dur_loss=dur_loss, prior_loss=prior_loss, diff_loss=diff_loss)

View File

@ -0,0 +1,159 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import torch
from icefall.utils import AttributeDict
from matcha.models.matcha_tts import MatchaTTS
from matcha.data.text_mel_datamodule import TextMelDataModule
def _get_data_params() -> AttributeDict:
    """Return the default settings handed to the data module.

    The settings are grouped below by topic only for readability; the
    returned ``AttributeDict`` is flat, apart from the nested
    ``data_statistics`` entry.
    """
    mel_statistics = AttributeDict(
        {
            "mel_mean": -5.517028331756592,
            "mel_std": 2.0643954277038574,
        }
    )

    dataset = {
        "name": "ljspeech",
        "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
        "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
    }
    loader = {
        "batch_size": 32,
        "num_workers": 3,
        "pin_memory": False,
    }
    text = {
        "cleaners": ["english_cleaners2"],
        "add_blank": True,
        "n_spks": 1,
    }
    audio = {
        "n_fft": 1024,
        "n_feats": 80,
        "sample_rate": 22050,
        "hop_length": 256,
        "win_length": 1024,
        "f_min": 0,
        "f_max": 8000,
    }
    misc = {
        "seed": 1234,
        "load_durations": False,
        "data_statistics": mel_statistics,
    }

    return AttributeDict({**dataset, **loader, **text, **audio, **misc})
def _get_model_params() -> AttributeDict:
    """Return the keyword arguments used to construct ``MatchaTTS``.

    ``data_statistics`` is included because the model constructor takes it
    as an argument and reads ``mel_mean`` and ``mel_std`` from it. The
    values are the same as those in ``_get_data_params``.
    """
    n_feats = 80
    filter_channels_dp = 256
    encoder_params_p_dropout = 0.1
    params = AttributeDict(
        {
            "n_vocab": 178,
            "n_spks": 1,  # for ljspeech.
            "spk_emb_dim": 64,
            "n_feats": n_feats,
            "out_size": None,  # or use 172
            "prior_loss": True,
            "use_precomputed_durations": False,
            "data_statistics": AttributeDict(
                {
                    "mel_mean": -5.517028331756592,
                    "mel_std": 2.0643954277038574,
                }
            ),
            "encoder": AttributeDict(
                {
                    "encoder_type": "RoPE Encoder",  # not used
                    "encoder_params": AttributeDict(
                        {
                            "n_feats": n_feats,
                            "n_channels": 192,
                            "filter_channels": 768,
                            "filter_channels_dp": filter_channels_dp,
                            "n_heads": 2,
                            "n_layers": 6,
                            "kernel_size": 3,
                            "p_dropout": encoder_params_p_dropout,
                            "spk_emb_dim": 64,
                            "n_spks": 1,
                            "prenet": True,
                        }
                    ),
                    "duration_predictor_params": AttributeDict(
                        {
                            "filter_channels_dp": filter_channels_dp,
                            "kernel_size": 3,
                            "p_dropout": encoder_params_p_dropout,
                        }
                    ),
                }
            ),
            "decoder": AttributeDict(
                {
                    "channels": [256, 256],
                    "dropout": 0.05,
                    "attention_head_dim": 64,
                    "n_blocks": 1,
                    "num_mid_blocks": 2,
                    "num_heads": 2,
                    "act_fn": "snakebeta",
                }
            ),
            "cfm": AttributeDict(
                {
                    "name": "CFM",
                    "solver": "euler",
                    "sigma_min": 1e-4,
                }
            ),
            "optimizer": AttributeDict(
                {
                    "lr": 1e-4,
                    "weight_decay": 0.0,
                }
            ),
        }
    )
    return params
def get_params():
    """Bundle the model and data settings into one ``AttributeDict``.

    Returns:
      An ``AttributeDict`` with the entries ``model`` and ``data``.
    """
    model_params = _get_model_params()
    data_params = _get_data_params()
    return AttributeDict({"model": model_params, "data": data_params})
def get_model(params):
    """Construct ``MatchaTTS`` from the keyword arguments in ``params.model``."""
    return MatchaTTS(**params.model)
def main():
    """Build the model from the default settings and print a summary.

    Prints the module structure followed by the total number of parameters.
    """
    params = get_params()

    # The data module is still instantiated as before, although nothing
    # iterates it here. According to the original author's notes, each batch
    # from ``data_module.train_dataloader()`` is a dict with the keys
    # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']:
    #   x: [batch_size, 289], torch.int64
    #   x_lengths: [batch_size], torch.int64
    #   y: [batch_size, n_feats, num_frames], torch.float32
    #   y_lengths: [batch_size], torch.int64
    #   spks: None
    #   filepaths: list, (batch_size,)
    #   x_texts: list, (batch_size,)
    #   durations: None
    data_module = TextMelDataModule(hparams=params.data)  # noqa: F841

    m = get_model(params)
    print(m)

    num_param = sum(p.numel() for p in m.parameters())
    print(f"Number of parameters: {num_param}")
# Limit PyTorch to a single thread for both intra-op and inter-op
# parallelism. This runs at import time, before ``main`` is invoked.
_NUM_THREADS = 1
torch.set_num_threads(_NUM_THREADS)
torch.set_num_interop_threads(_NUM_THREADS)

if __name__ == "__main__":
    main()

View File

@ -2,12 +2,111 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import torch
from icefall.utils import AttributeDict
from matcha.models.matcha_tts import MatchaTTS
import torch.nn as nn
from lhotse.utils import fix_random_seed
from matcha.data.text_mel_datamodule import TextMelDataModule
from icefall.env import get_env_info
from matcha.models.matcha_tts import MatchaTTS
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from utils2 import MetricsTracker, plot_feature
from icefall.checkpoint import load_checkpoint, save_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
    """Create the command-line parser for the training script.

    Returns:
      An ``argparse.ArgumentParser`` with the options ``--tensorboard``,
      ``--num-epochs``, ``--start-epoch``, ``--exp-dir``, ``--seed``,
      ``--save-every-n``, ``--use-fp16`` and ``--batch-size``. Defaults are
      included in the ``--help`` output.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1000,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp",
        help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    # The stray double quote that used to follow "epochs" is removed.
    parser.add_argument(
        "--save-every-n",
        type=int,
        default=10,
        help="""Save checkpoint after processing this number of epochs
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size passed to the data module.",
    )

    return parser
def get_data_statistics():
    """Return the mel statistics shared by the data and model settings.

    Returns:
      An ``AttributeDict`` with the entries ``mel_mean`` and ``mel_std``.
    """
    # presumably computed on the LJSpeech training set - TODO confirm
    mel_mean = -5.517028331756592
    mel_std = 2.0643954277038574
    return AttributeDict({"mel_mean": mel_mean, "mel_std": mel_std})
def _get_data_params() -> AttributeDict:
@ -16,7 +115,6 @@ def _get_data_params() -> AttributeDict:
"name": "ljspeech",
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
"batch_size": 32,
"num_workers": 3,
"pin_memory": False,
"cleaners": ["english_cleaners2"],
@ -31,12 +129,7 @@ def _get_data_params() -> AttributeDict:
"f_max": 8000,
"seed": 1234,
"load_durations": False,
"data_statistics": AttributeDict(
{
"mel_mean": -5.517028331756592,
"mel_std": 2.0643954277038574,
}
),
"data_statistics": get_data_statistics(),
}
)
return params
@ -55,6 +148,7 @@ def _get_model_params() -> AttributeDict:
"out_size": None, # or use 172
"prior_loss": True,
"use_precomputed_durations": False,
"data_statistics": get_data_statistics(),
"encoder": AttributeDict(
{
"encoder_type": "RoPE Encoder", # not used
@ -115,42 +209,368 @@ def _get_model_params() -> AttributeDict:
def get_params():
    """Return the default training parameters.

    The result holds the model and data settings, the best losses and epochs
    seen so far, the global batch counter, the logging and validation
    intervals, and the environment information from ``get_env_info``.
    """
    # NOTE(review): "model"/"data" duplicate "model_args"/"data_args"; they
    # look like names from before a rename - confirm whether they can go.
    config_sections = {
        "model": _get_model_params(),
        "data": _get_data_params(),
        "model_args": _get_model_params(),
        "data_args": _get_data_params(),
    }
    best_so_far = {
        "best_train_loss": float("inf"),
        "best_valid_loss": float("inf"),
        "best_train_epoch": -1,
        "best_valid_epoch": -1,
    }
    schedule = {
        # Starts at -1 so that the first increment yields 0.
        "batch_idx_train": -1,
        "log_interval": 50,
        "valid_interval": 2000,
    }
    return AttributeDict(
        {
            **config_sections,
            **best_so_far,
            **schedule,
            "env_info": get_env_info(),
        }
    )
def get_model(params):
    """Construct ``MatchaTTS`` from the keyword arguments in ``params.model_args``.

    The model is built exactly once. The earlier construction from
    ``params.model`` was discarded immediately and only wasted time and
    memory.
    """
    return MatchaTTS(**params.model_args)
def load_checkpoint_if_available(
    params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
    """Resume from ``exp_dir/epoch-{start_epoch-1}.pt`` when asked to.

    Nothing is loaded unless ``params.start_epoch`` is larger than 1.
    Otherwise the checkpoint of the preceding epoch is loaded into ``model``
    and ``best_train_epoch``, ``best_valid_epoch``, ``batch_idx_train``,
    ``best_train_loss`` and ``best_valid_loss`` in ``params`` are overwritten
    with the saved values.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      The dict returned by ``load_checkpoint``, or ``None`` when no
      checkpoint was loaded.
    """
    if params.start_epoch <= 1:
        return None

    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(filename, model=model)

    restored_keys = (
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    )
    for name in restored_keys:
        params[name] = saved_params[name]

    return saved_params
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> MetricsTracker:
    """Run the validation process.

    The model is put in eval mode and is NOT switched back to train mode;
    that is left to the caller.

    Args:
      params:
        It is returned by :func:`get_params`. ``best_valid_loss`` and
        ``best_valid_epoch`` are updated when the loss improves.
      model:
        The model to evaluate, optionally wrapped in DDP.
      valid_dl:
        Dataloader for the validation dataset.
      world_size:
        Number of processes; statistics are reduced when it is > 1.
      rank:
        Unused here; kept for interface compatibility.
    Returns:
      The accumulated, sample-weighted losses over the validation set.
    """
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # ``get_losses`` is defined on the wrapped module, not on the DDP wrapper.
    inner_model = model.module if isinstance(model, DDP) else model

    # used to summary the stats over iterations
    tot_loss = MetricsTracker()

    with torch.no_grad():
        for batch in valid_dl:
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(device)

            losses = inner_model.get_losses(batch)
            batch_size = batch["x"].shape[0]

            # Weight each loss by the batch size so that dividing by the
            # total number of samples gives a per-sample average.
            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size
            s = 0
            for key, value in losses.items():
                v = value.detach().item()
                loss_info[key] = v * batch_size
                s += v * batch_size
            loss_info["tot_loss"] = s

            # summary stats
            tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: Optimizer,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The sample-weighted average training loss of the epoch is saved in
    `params.train_loss`. The validation process runs whenever
    `params.batch_idx_train % params.valid_interval == 1`.

    If an exception is raised while processing a batch, the model is saved
    to `exp-dir/bad-model-{rank}.pt` and the exception is re-raised.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of processes; forwarded to the validation routine.
      rank:
        Rank of this process; used in the names of saved bad models.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer=optimizer,
            scaler=scaler,
            rank=rank,
        )

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(device)

        batch_size = batch["x"].shape[0]

        try:
            # NOTE(review): ``get_losses`` is not reachable through a DDP
            # wrapper; this needs revisiting before multi-GPU training.
            with autocast(enabled=params.use_fp16):
                losses = model.get_losses(batch)
                loss = sum(losses.values())

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            # ``step`` must be followed by ``update``: with fp16 enabled the
            # scaler refuses a second ``step`` without it and would never
            # lower the scale after an overflow.
            scaler.update()

            # Weight each loss by the batch size so that dividing by the
            # total number of samples gives a per-sample average.
            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            s = 0
            for key, value in losses.items():
                v = value.detach().item()
                loss_info[key] = v * batch_size
                s += v * batch_size
            loss_info["tot_loss"] = s

            tot_loss = tot_loss + loss_info
        except:  # noqa
            # Deliberately catches everything: dump the model for debugging
            # and let the original exception propagate.
            save_bad_model()
            raise

        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if params.batch_idx_train % params.log_interval == 0:
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

        if params.batch_idx_train % params.valid_interval == 1:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
                rank=rank,
            )
            # Validation leaves the model in eval mode; switch back.
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
def main():
    """Parse the command line and run single-process training.

    The steps are: merge the command-line options into the default
    parameters, set up logging and (optionally) tensorboard, build the model
    and the data module, optionally resume from a checkpoint, and train from
    ``start_epoch`` up to ``num_epochs``, saving a checkpoint every
    ``save_every_n`` epochs and after the last epoch.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    # The batch size is a data-module setting, so move it there.
    params.data_args.batch_size = params.batch_size
    del params.batch_size

    fix_random_seed(params.seed)
    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    # Honour ``--tensorboard``; every later use checks for ``None``.
    tb_writer = None
    if params.tensorboard:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")
    print(f"Device: {device}")
    logging.info(params)
    print(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of parameters: {num_param}")
    print(f"Number of parameters: {num_param}")

    logging.info("About to create datamodule")
    data_module = TextMelDataModule(hparams=params.data_args)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)

    # NOTE(review): the optimizer state is saved with every checkpoint but is
    # not restored here when resuming - confirm whether that is intended.
    optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    train_dl = data_module.train_dataloader()
    valid_dl = data_module.val_dataloader()

    rank = 0

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")
        # A different but reproducible seed for every epoch.
        fix_random_seed(params.seed + epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            optimizer=optimizer,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
        )

        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer=optimizer,
                scaler=scaler,
                rank=rank,
            )
            # The best-model copies need ``filename``, so they can only be
            # made for epochs whose checkpoint was just written.
            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

    logging.info("Done!")
# Limit PyTorch to a single thread for both intra-op and inter-op
# parallelism. This runs at import time.
_NUM_THREADS = 1
torch.set_num_threads(_NUM_THREADS)
torch.set_num_interop_threads(_NUM_THREADS)

View File

@ -0,0 +1 @@
../vits/utils.py