mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
First working version.
This commit is contained in:
parent
ccd2dcc9f9
commit
56d3b92f3f
178
egs/ljspeech/TTS/matcha/inference.py
Executable file
178
egs/ljspeech/TTS/matcha/inference.py
Executable file
@ -0,0 +1,178 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from matcha.hifigan.config import v1
|
||||||
|
from matcha.hifigan.denoiser import Denoiser
|
||||||
|
from matcha.hifigan.models import Generator as HiFiGAN
|
||||||
|
from matcha.text import sequence_to_text, text_to_sequence
|
||||||
|
from matcha.utils.utils import intersperse
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from train import get_model, get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
from icefall.utils import AttributeDict
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=140,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=Path,
|
||||||
|
default="matcha/exp",
|
||||||
|
help="""The experiment dir.
|
||||||
|
It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocoder(checkpoint_path):
|
||||||
|
h = AttributeDict(v1)
|
||||||
|
hifigan = HiFiGAN(h).to("cpu")
|
||||||
|
hifigan.load_state_dict(
|
||||||
|
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||||
|
)
|
||||||
|
_ = hifigan.eval()
|
||||||
|
hifigan.remove_weight_norm()
|
||||||
|
return hifigan
|
||||||
|
|
||||||
|
|
||||||
|
def to_waveform(mel, vocoder, denoiser):
|
||||||
|
audio = vocoder(mel).clamp(-1, 1)
|
||||||
|
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
|
||||||
|
return audio.cpu().squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_folder(filename: str, output: dict, folder: str):
|
||||||
|
folder = Path(folder)
|
||||||
|
folder.mkdir(exist_ok=True, parents=True)
|
||||||
|
np.save(folder / f"{filename}", output["mel"].cpu().numpy())
|
||||||
|
sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
|
||||||
|
|
||||||
|
|
||||||
|
def process_text(text: str):
|
||||||
|
x = torch.tensor(
|
||||||
|
intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu",
|
||||||
|
)[None]
|
||||||
|
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
|
||||||
|
x_phones = sequence_to_text(x.squeeze(0).tolist())
|
||||||
|
return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
|
||||||
|
|
||||||
|
|
||||||
|
def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
|
||||||
|
text_processed = process_text(text)
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
output = model.synthesise(
|
||||||
|
text_processed["x"],
|
||||||
|
text_processed["x_lengths"],
|
||||||
|
n_timesteps=n_timesteps,
|
||||||
|
temperature=temperature,
|
||||||
|
spks=spks,
|
||||||
|
length_scale=length_scale,
|
||||||
|
)
|
||||||
|
print("output.shape", list(output.keys()), output["mel"].shape)
|
||||||
|
# merge everything to one dict
|
||||||
|
output.update({"start_t": start_t, **text_processed})
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
params = get_params()
|
||||||
|
|
||||||
|
params.update(vars(args))
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
|
||||||
|
denoiser = Denoiser(vocoder, mode="zeros")
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
|
||||||
|
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Number of ODE Solver steps
|
||||||
|
n_timesteps = 2
|
||||||
|
|
||||||
|
# Changes to the speaking rate
|
||||||
|
length_scale = 1.0
|
||||||
|
|
||||||
|
# Sampling temperature
|
||||||
|
temperature = 0.667
|
||||||
|
|
||||||
|
outputs, rtfs = [], []
|
||||||
|
rtfs_w = []
|
||||||
|
for i, text in enumerate(tqdm(texts)):
|
||||||
|
output = synthesise(
|
||||||
|
model=model,
|
||||||
|
n_timesteps=n_timesteps,
|
||||||
|
text=text,
|
||||||
|
length_scale=length_scale,
|
||||||
|
temperature=temperature,
|
||||||
|
) # , torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))
|
||||||
|
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||||
|
|
||||||
|
# Compute Real Time Factor (RTF) with HiFi-GAN
|
||||||
|
t = (dt.datetime.now() - output["start_t"]).total_seconds()
|
||||||
|
rtf_w = t * 22050 / (output["waveform"].shape[-1])
|
||||||
|
|
||||||
|
# Pretty print
|
||||||
|
print(f"{'*' * 53}")
|
||||||
|
print(f"Input text - {i}")
|
||||||
|
print(f"{'-' * 53}")
|
||||||
|
print(output["x_orig"])
|
||||||
|
print(f"{'*' * 53}")
|
||||||
|
print(f"Phonetised text - {i}")
|
||||||
|
print(f"{'-' * 53}")
|
||||||
|
print(output["x_phones"])
|
||||||
|
print(f"{'*' * 53}")
|
||||||
|
print(f"RTF:\t\t{output['rtf']:.6f}")
|
||||||
|
print(f"RTF Waveform:\t{rtf_w:.6f}")
|
||||||
|
rtfs.append(output["rtf"])
|
||||||
|
rtfs_w.append(rtf_w)
|
||||||
|
|
||||||
|
# Save the generated waveform
|
||||||
|
save_to_folder(i, output, folder="./my-output")
|
||||||
|
|
||||||
|
print(f"Number of ODE steps: {n_timesteps}")
|
||||||
|
print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
|
||||||
|
print(
|
||||||
|
f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -5,6 +5,7 @@ import random
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import matcha.utils.monotonic_align as monotonic_align
|
import matcha.utils.monotonic_align as monotonic_align
|
||||||
|
|
||||||
# from matcha import utils
|
# from matcha import utils
|
||||||
# from matcha.models.baselightningmodule import BaseLightningClass
|
# from matcha.models.baselightningmodule import BaseLightningClass
|
||||||
from matcha.models.components.flow_matching import CFM
|
from matcha.models.components.flow_matching import CFM
|
||||||
@ -30,7 +31,7 @@ class MatchaTTS(torch.nn.Module): # 🍵
|
|||||||
encoder,
|
encoder,
|
||||||
decoder,
|
decoder,
|
||||||
cfm,
|
cfm,
|
||||||
# data_statistics,
|
data_statistics,
|
||||||
out_size,
|
out_size,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
@ -71,9 +72,13 @@ class MatchaTTS(torch.nn.Module): # 🍵
|
|||||||
)
|
)
|
||||||
|
|
||||||
# self.update_data_statistics(data_statistics)
|
# self.update_data_statistics(data_statistics)
|
||||||
|
self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
|
||||||
|
self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
|
def synthesise(
|
||||||
|
self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates mel-spectrogram from text. Returns:
|
Generates mel-spectrogram from text. Returns:
|
||||||
1. encoder outputs
|
1. encoder outputs
|
||||||
@ -149,7 +154,17 @@ class MatchaTTS(torch.nn.Module): # 🍵
|
|||||||
"rtf": rtf,
|
"rtf": rtf,
|
||||||
}
|
}
|
||||||
|
|
||||||
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
x_lengths,
|
||||||
|
y,
|
||||||
|
y_lengths,
|
||||||
|
spks=None,
|
||||||
|
out_size=None,
|
||||||
|
cond=None,
|
||||||
|
durations=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Computes 3 losses:
|
Computes 3 losses:
|
||||||
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
|
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
|
||||||
@ -187,7 +202,9 @@ class MatchaTTS(torch.nn.Module): # 🍵
|
|||||||
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
||||||
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
|
factor = -0.5 * torch.ones(
|
||||||
|
mu_x.shape, dtype=mu_x.dtype, device=mu_x.device
|
||||||
|
)
|
||||||
y_square = torch.matmul(factor.transpose(1, 2), y**2)
|
y_square = torch.matmul(factor.transpose(1, 2), y**2)
|
||||||
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
|
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
|
||||||
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
|
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
|
||||||
@ -206,12 +223,25 @@ class MatchaTTS(torch.nn.Module): # 🍵
|
|||||||
# - Do not need this hack for Matcha-TTS, but it works with it as well
|
# - Do not need this hack for Matcha-TTS, but it works with it as well
|
||||||
if not isinstance(out_size, type(None)):
|
if not isinstance(out_size, type(None)):
|
||||||
max_offset = (y_lengths - out_size).clamp(0)
|
max_offset = (y_lengths - out_size).clamp(0)
|
||||||
offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
|
offset_ranges = list(
|
||||||
|
zip([0] * max_offset.shape[0], max_offset.cpu().numpy())
|
||||||
|
)
|
||||||
out_offset = torch.LongTensor(
|
out_offset = torch.LongTensor(
|
||||||
[torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
|
[
|
||||||
|
torch.tensor(random.choice(range(start, end)) if end > start else 0)
|
||||||
|
for start, end in offset_ranges
|
||||||
|
]
|
||||||
).to(y_lengths)
|
).to(y_lengths)
|
||||||
attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
|
attn_cut = torch.zeros(
|
||||||
y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
|
attn.shape[0],
|
||||||
|
attn.shape[1],
|
||||||
|
out_size,
|
||||||
|
dtype=attn.dtype,
|
||||||
|
device=attn.device,
|
||||||
|
)
|
||||||
|
y_cut = torch.zeros(
|
||||||
|
y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device
|
||||||
|
)
|
||||||
|
|
||||||
y_cut_lengths = []
|
y_cut_lengths = []
|
||||||
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
|
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
|
||||||
@ -233,12 +263,36 @@ class MatchaTTS(torch.nn.Module): # 🍵
|
|||||||
mu_y = mu_y.transpose(1, 2)
|
mu_y = mu_y.transpose(1, 2)
|
||||||
|
|
||||||
# Compute loss of the decoder
|
# Compute loss of the decoder
|
||||||
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
|
diff_loss, _ = self.decoder.compute_loss(
|
||||||
|
x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond
|
||||||
|
)
|
||||||
|
|
||||||
if self.prior_loss:
|
if self.prior_loss:
|
||||||
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
|
prior_loss = torch.sum(
|
||||||
|
0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
|
||||||
|
)
|
||||||
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
||||||
else:
|
else:
|
||||||
prior_loss = 0
|
prior_loss = 0
|
||||||
|
|
||||||
return dur_loss, prior_loss, diff_loss, attn
|
return dur_loss, prior_loss, diff_loss, attn
|
||||||
|
|
||||||
|
def get_losses(self, batch):
|
||||||
|
x, x_lengths = batch["x"], batch["x_lengths"]
|
||||||
|
y, y_lengths = batch["y"], batch["y_lengths"]
|
||||||
|
spks = batch["spks"]
|
||||||
|
|
||||||
|
dur_loss, prior_loss, diff_loss, *_ = self(
|
||||||
|
x=x,
|
||||||
|
x_lengths=x_lengths,
|
||||||
|
y=y,
|
||||||
|
y_lengths=y_lengths,
|
||||||
|
spks=spks,
|
||||||
|
out_size=self.out_size,
|
||||||
|
durations=batch["durations"],
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"dur_loss": dur_loss,
|
||||||
|
"prior_loss": prior_loss,
|
||||||
|
"diff_loss": diff_loss,
|
||||||
|
}
|
||||||
|
159
egs/ljspeech/TTS/matcha/test-train.py
Normal file
159
egs/ljspeech/TTS/matcha/test-train.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict
|
||||||
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
|
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"name": "ljspeech",
|
||||||
|
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||||
|
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||||
|
"batch_size": 32,
|
||||||
|
"num_workers": 3,
|
||||||
|
"pin_memory": False,
|
||||||
|
"cleaners": ["english_cleaners2"],
|
||||||
|
"add_blank": True,
|
||||||
|
"n_spks": 1,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"n_feats": 80,
|
||||||
|
"sample_rate": 22050,
|
||||||
|
"hop_length": 256,
|
||||||
|
"win_length": 1024,
|
||||||
|
"f_min": 0,
|
||||||
|
"f_max": 8000,
|
||||||
|
"seed": 1234,
|
||||||
|
"load_durations": False,
|
||||||
|
"data_statistics": AttributeDict(
|
||||||
|
{
|
||||||
|
"mel_mean": -5.517028331756592,
|
||||||
|
"mel_std": 2.0643954277038574,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_params() -> AttributeDict:
|
||||||
|
n_feats = 80
|
||||||
|
filter_channels_dp = 256
|
||||||
|
encoder_params_p_dropout = 0.1
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"n_vocab": 178,
|
||||||
|
"n_spks": 1, # for ljspeech.
|
||||||
|
"spk_emb_dim": 64,
|
||||||
|
"n_feats": n_feats,
|
||||||
|
"out_size": None, # or use 172
|
||||||
|
"prior_loss": True,
|
||||||
|
"use_precomputed_durations": False,
|
||||||
|
"encoder": AttributeDict(
|
||||||
|
{
|
||||||
|
"encoder_type": "RoPE Encoder", # not used
|
||||||
|
"encoder_params": AttributeDict(
|
||||||
|
{
|
||||||
|
"n_feats": n_feats,
|
||||||
|
"n_channels": 192,
|
||||||
|
"filter_channels": 768,
|
||||||
|
"filter_channels_dp": filter_channels_dp,
|
||||||
|
"n_heads": 2,
|
||||||
|
"n_layers": 6,
|
||||||
|
"kernel_size": 3,
|
||||||
|
"p_dropout": encoder_params_p_dropout,
|
||||||
|
"spk_emb_dim": 64,
|
||||||
|
"n_spks": 1,
|
||||||
|
"prenet": True,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"duration_predictor_params": AttributeDict(
|
||||||
|
{
|
||||||
|
"filter_channels_dp": filter_channels_dp,
|
||||||
|
"kernel_size": 3,
|
||||||
|
"p_dropout": encoder_params_p_dropout,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"decoder": AttributeDict(
|
||||||
|
{
|
||||||
|
"channels": [256, 256],
|
||||||
|
"dropout": 0.05,
|
||||||
|
"attention_head_dim": 64,
|
||||||
|
"n_blocks": 1,
|
||||||
|
"num_mid_blocks": 2,
|
||||||
|
"num_heads": 2,
|
||||||
|
"act_fn": "snakebeta",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"cfm": AttributeDict(
|
||||||
|
{
|
||||||
|
"name": "CFM",
|
||||||
|
"solver": "euler",
|
||||||
|
"sigma_min": 1e-4,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"optimizer": AttributeDict(
|
||||||
|
{
|
||||||
|
"lr": 1e-4,
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_params():
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"model": _get_model_params(),
|
||||||
|
"data": _get_data_params(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(params):
|
||||||
|
m = MatchaTTS(**params.model)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
params = get_params()
|
||||||
|
|
||||||
|
data_module = TextMelDataModule(hparams=params.data)
|
||||||
|
if False:
|
||||||
|
for b in data_module.train_dataloader():
|
||||||
|
assert isinstance(b, dict)
|
||||||
|
# b.keys()
|
||||||
|
# ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
|
||||||
|
# x: [batch_size, 289], torch.int64
|
||||||
|
# x_lengths: [batch_size], torch.int64
|
||||||
|
# y: [batch_size, n_feats, num_frames], torch.float32
|
||||||
|
# y_lengths: [batch_size], torch.int64
|
||||||
|
# spks: None
|
||||||
|
# filepaths: list, (batch_size,)
|
||||||
|
# x_texts: list, (batch_size,)
|
||||||
|
# durations: None
|
||||||
|
|
||||||
|
m = get_model(params)
|
||||||
|
print(m)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in m.parameters()])
|
||||||
|
print(f"Number of parameters: {num_param}")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -2,12 +2,111 @@
|
|||||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
from icefall.utils import AttributeDict
|
|
||||||
from matcha.models.matcha_tts import MatchaTTS
|
|
||||||
from matcha.data.text_mel_datamodule import TextMelDataModule
|
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||||
|
from icefall.env import get_env_info
|
||||||
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from utils2 import MetricsTracker, plot_feature
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint, save_checkpoint
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorboard",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Should various information be logged in tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-epochs",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of epochs to train.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--start-epoch",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Resume training from this epoch. It should be positive.
|
||||||
|
If larger than 1, it will load checkpoint from
|
||||||
|
exp-dir/epoch-{start_epoch-1}.pt
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=Path,
|
||||||
|
default="matcha/exp",
|
||||||
|
help="""The experiment dir.
|
||||||
|
It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-every-n",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="""Save checkpoint after processing this number of epochs"
|
||||||
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
|
params.cur_epoch % save_every_n == 0. The checkpoint filename
|
||||||
|
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
|
||||||
|
Since it will take around 1000 epochs, we suggest using a large
|
||||||
|
save_every_n to save disk space.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fp16",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to use half precision training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_statistics():
|
||||||
|
return AttributeDict(
|
||||||
|
{
|
||||||
|
"mel_mean": -5.517028331756592,
|
||||||
|
"mel_std": 2.0643954277038574,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_data_params() -> AttributeDict:
|
def _get_data_params() -> AttributeDict:
|
||||||
@ -16,7 +115,6 @@ def _get_data_params() -> AttributeDict:
|
|||||||
"name": "ljspeech",
|
"name": "ljspeech",
|
||||||
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||||
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||||
"batch_size": 32,
|
|
||||||
"num_workers": 3,
|
"num_workers": 3,
|
||||||
"pin_memory": False,
|
"pin_memory": False,
|
||||||
"cleaners": ["english_cleaners2"],
|
"cleaners": ["english_cleaners2"],
|
||||||
@ -31,12 +129,7 @@ def _get_data_params() -> AttributeDict:
|
|||||||
"f_max": 8000,
|
"f_max": 8000,
|
||||||
"seed": 1234,
|
"seed": 1234,
|
||||||
"load_durations": False,
|
"load_durations": False,
|
||||||
"data_statistics": AttributeDict(
|
"data_statistics": get_data_statistics(),
|
||||||
{
|
|
||||||
"mel_mean": -5.517028331756592,
|
|
||||||
"mel_std": 2.0643954277038574,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
@ -55,6 +148,7 @@ def _get_model_params() -> AttributeDict:
|
|||||||
"out_size": None, # or use 172
|
"out_size": None, # or use 172
|
||||||
"prior_loss": True,
|
"prior_loss": True,
|
||||||
"use_precomputed_durations": False,
|
"use_precomputed_durations": False,
|
||||||
|
"data_statistics": get_data_statistics(),
|
||||||
"encoder": AttributeDict(
|
"encoder": AttributeDict(
|
||||||
{
|
{
|
||||||
"encoder_type": "RoPE Encoder", # not used
|
"encoder_type": "RoPE Encoder", # not used
|
||||||
@ -115,42 +209,368 @@ def _get_model_params() -> AttributeDict:
|
|||||||
def get_params():
|
def get_params():
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"model": _get_model_params(),
|
"model_args": _get_model_params(),
|
||||||
"data": _get_data_params(),
|
"data_args": _get_data_params(),
|
||||||
|
"best_train_loss": float("inf"),
|
||||||
|
"best_valid_loss": float("inf"),
|
||||||
|
"best_train_epoch": -1,
|
||||||
|
"best_valid_epoch": -1,
|
||||||
|
"batch_idx_train": -1, # 0
|
||||||
|
"log_interval": 50,
|
||||||
|
"valid_interval": 2000,
|
||||||
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_model(params):
|
def get_model(params):
|
||||||
m = MatchaTTS(**params.model)
|
m = MatchaTTS(**params.model_args)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_if_available(
|
||||||
|
params: AttributeDict, model: nn.Module
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Load checkpoint from file.
|
||||||
|
|
||||||
|
If params.start_epoch is larger than 1, it will load the checkpoint from
|
||||||
|
`params.start_epoch - 1`.
|
||||||
|
|
||||||
|
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||||
|
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||||
|
and `best_valid_loss` in `params`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
The return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
Returns:
|
||||||
|
Return a dict containing previously saved training info.
|
||||||
|
"""
|
||||||
|
if params.start_epoch > 1:
|
||||||
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert filename.is_file(), f"{filename} does not exist!"
|
||||||
|
|
||||||
|
saved_params = load_checkpoint(filename, model=model)
|
||||||
|
|
||||||
|
keys = [
|
||||||
|
"best_train_epoch",
|
||||||
|
"best_valid_epoch",
|
||||||
|
"batch_idx_train",
|
||||||
|
"best_train_loss",
|
||||||
|
"best_valid_loss",
|
||||||
|
]
|
||||||
|
for k in keys:
|
||||||
|
params[k] = saved_params[k]
|
||||||
|
|
||||||
|
return saved_params
|
||||||
|
|
||||||
|
|
||||||
|
def compute_validation_loss(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: Union[nn.Module, DDP],
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
world_size: int = 1,
|
||||||
|
rank: int = 0,
|
||||||
|
) -> MetricsTracker:
|
||||||
|
"""Run the validation process."""
|
||||||
|
model.eval()
|
||||||
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
|
|
||||||
|
# used to summary the stats over iterations
|
||||||
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
|
for key, value in batch.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
batch[key] = value.to(device)
|
||||||
|
losses = model.get_losses(batch)
|
||||||
|
loss = sum(losses.values())
|
||||||
|
|
||||||
|
batch_size = batch["x"].shape[0]
|
||||||
|
|
||||||
|
loss_info = MetricsTracker()
|
||||||
|
loss_info["samples"] = batch_size
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
|
||||||
|
for key, value in losses.items():
|
||||||
|
v = value.detach().item()
|
||||||
|
loss_info[key] = v * batch_size
|
||||||
|
s += v * batch_size
|
||||||
|
|
||||||
|
loss_info["tot_loss"] = s
|
||||||
|
|
||||||
|
# summary stats
|
||||||
|
tot_loss = tot_loss + loss_info
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
tot_loss.reduce(device)
|
||||||
|
|
||||||
|
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||||
|
if loss_value < params.best_valid_loss:
|
||||||
|
params.best_valid_epoch = params.cur_epoch
|
||||||
|
params.best_valid_loss = loss_value
|
||||||
|
|
||||||
|
return tot_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: Union[nn.Module, DDP],
|
||||||
|
optimizer: Optimizer,
|
||||||
|
train_dl: torch.utils.data.DataLoader,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
scaler: GradScaler,
|
||||||
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
|
world_size: int = 1,
|
||||||
|
rank: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
|
The training loss from the mean of all frames is saved in
|
||||||
|
`params.train_loss`. It runs the validation process every
|
||||||
|
`params.valid_interval` batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training.
|
||||||
|
optimizer:
|
||||||
|
The optimizer.
|
||||||
|
train_dl:
|
||||||
|
Dataloader for the training dataset.
|
||||||
|
valid_dl:
|
||||||
|
Dataloader for the validation dataset.
|
||||||
|
scaler:
|
||||||
|
The scaler used for mix precision training.
|
||||||
|
tb_writer:
|
||||||
|
Writer to write log messages to tensorboard.
|
||||||
|
"""
|
||||||
|
model.train()
|
||||||
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
|
|
||||||
|
# used to track the stats over iterations in one epoch
|
||||||
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
|
saved_bad_model = False
|
||||||
|
|
||||||
|
# used to track the stats over iterations in one epoch
|
||||||
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
|
saved_bad_model = False
|
||||||
|
|
||||||
|
def save_bad_model(suffix: str = ""):
|
||||||
|
save_checkpoint(
|
||||||
|
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scaler=scaler,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
|
params.batch_idx_train += 1
|
||||||
|
for key, value in batch.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
batch[key] = value.to(device)
|
||||||
|
|
||||||
|
batch_size = batch["x"].shape[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
with autocast(enabled=params.use_fp16):
|
||||||
|
losses = model.get_losses(batch)
|
||||||
|
|
||||||
|
loss = sum(losses.values())
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
|
||||||
|
loss_info = MetricsTracker()
|
||||||
|
loss_info["samples"] = batch_size
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
|
||||||
|
for key, value in losses.items():
|
||||||
|
v = value.detach().item()
|
||||||
|
loss_info[key] = v * batch_size
|
||||||
|
s += v * batch_size
|
||||||
|
|
||||||
|
loss_info["tot_loss"] = s
|
||||||
|
|
||||||
|
tot_loss = tot_loss + loss_info
|
||||||
|
except: # noqa
|
||||||
|
save_bad_model()
|
||||||
|
raise
|
||||||
|
|
||||||
|
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||||
|
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||||
|
# of the grad scaler is configurable, but we can't configure it to have different
|
||||||
|
# behavior depending on the current grad scale.
|
||||||
|
cur_grad_scale = scaler._scale.item()
|
||||||
|
|
||||||
|
if cur_grad_scale < 8.0 or (
|
||||||
|
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
|
||||||
|
):
|
||||||
|
scaler.update(cur_grad_scale * 2.0)
|
||||||
|
if cur_grad_scale < 0.01:
|
||||||
|
if not saved_bad_model:
|
||||||
|
save_bad_model(suffix="-first-warning")
|
||||||
|
saved_bad_model = True
|
||||||
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
|
if cur_grad_scale < 1.0e-05:
|
||||||
|
save_bad_model()
|
||||||
|
raise RuntimeError(
|
||||||
|
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.batch_idx_train % params.log_interval == 0:
|
||||||
|
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
|
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
|
||||||
|
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||||
|
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if tb_writer is not None:
|
||||||
|
loss_info.write_summary(
|
||||||
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
|
)
|
||||||
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||||
|
if params.use_fp16:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.batch_idx_train % params.valid_interval == 1:
|
||||||
|
logging.info("Computing validation loss")
|
||||||
|
valid_info = compute_validation_loss(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
model.train()
|
||||||
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||||
|
logging.info(
|
||||||
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
|
)
|
||||||
|
if tb_writer is not None:
|
||||||
|
valid_info.write_summary(
|
||||||
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||||
|
params.train_loss = loss_value
|
||||||
|
if params.train_loss < params.best_train_loss:
|
||||||
|
params.best_train_epoch = params.cur_epoch
|
||||||
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
params = get_params()
|
params = get_params()
|
||||||
|
|
||||||
data_module = TextMelDataModule(hparams=params.data)
|
params.update(vars(args))
|
||||||
if False:
|
|
||||||
for b in data_module.train_dataloader():
|
|
||||||
assert isinstance(b, dict)
|
|
||||||
# b.keys()
|
|
||||||
# ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
|
|
||||||
# x: [batch_size, 289], torch.int64
|
|
||||||
# x_lengths: [batch_size], torch.int64
|
|
||||||
# y: [batch_size, n_feats, num_frames], torch.float32
|
|
||||||
# y_lengths: [batch_size], torch.int64
|
|
||||||
# spks: None
|
|
||||||
# filepaths: list, (batch_size,)
|
|
||||||
# x_texts: list, (batch_size,)
|
|
||||||
# durations: None
|
|
||||||
|
|
||||||
m = get_model(params)
|
params.data_args.batch_size = params.batch_size
|
||||||
print(m)
|
del params.batch_size
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in m.parameters()])
|
fix_random_seed(params.seed)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
|
logging.info("Training started")
|
||||||
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
print(f"Device: {device}")
|
||||||
|
print(f"Device: {device}")
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
print(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of parameters: {num_param}")
|
||||||
print(f"Number of parameters: {num_param}")
|
print(f"Number of parameters: {num_param}")
|
||||||
|
|
||||||
|
logging.info("About to create datamodule")
|
||||||
|
data_module = TextMelDataModule(hparams=params.data_args)
|
||||||
|
|
||||||
|
assert params.start_epoch > 0, params.start_epoch
|
||||||
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
|
||||||
|
|
||||||
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
|
logging.info("Loading grad scaler state dict")
|
||||||
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
|
|
||||||
|
train_dl = data_module.train_dataloader()
|
||||||
|
valid_dl = data_module.val_dataloader()
|
||||||
|
|
||||||
|
rank = 0
|
||||||
|
|
||||||
|
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||||
|
logging.info(f"Start epoch {epoch}")
|
||||||
|
fix_random_seed(params.seed + epoch - 1)
|
||||||
|
|
||||||
|
params.cur_epoch = epoch
|
||||||
|
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
|
train_one_epoch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
train_dl=train_dl,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
scaler=scaler,
|
||||||
|
tb_writer=tb_writer,
|
||||||
|
)
|
||||||
|
|
||||||
|
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
|
||||||
|
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||||
|
save_checkpoint(
|
||||||
|
filename=filename,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scaler=scaler,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
if rank == 0:
|
||||||
|
if params.best_train_epoch == params.cur_epoch:
|
||||||
|
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_train_filename)
|
||||||
|
|
||||||
|
if params.best_valid_epoch == params.cur_epoch:
|
||||||
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
1
egs/ljspeech/TTS/matcha/utils2.py
Symbolic link
1
egs/ljspeech/TTS/matcha/utils2.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../vits/utils.py
|
Loading…
x
Reference in New Issue
Block a user