Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00

remove more unused code

This commit is contained in:
  parent f6328edf5b
  commit 10c099ac90
@@ -5,7 +5,6 @@ from typing import Any, Dict

import onnx
import torch

from inference import load_vocoder

@@ -2,17 +2,17 @@
import argparse
import datetime as dt
import json
import logging
from pathlib import Path

import json
import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser
from tokenizer import Tokenizer
from matcha.hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from tqdm.auto import tqdm
from train import get_model, get_params

@@ -1,223 +0,0 @@
"""
This is a base lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
"""
import inspect
from abc import ABC
from typing import Any, Dict

import torch
from lightning import LightningModule
from lightning.pytorch.utilities import grad_norm

from matcha import utils
from matcha.utils.utils import plot_tensor

log = utils.get_pylogger(__name__)


class BaseLightningClass(LightningModule, ABC):
    def update_data_statistics(self, data_statistics):
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler not in (None, {}):
            scheduler_args = {}
            # Manage last epoch for exponential schedulers
            if (
                "last_epoch"
                in inspect.signature(self.hparams.scheduler.scheduler).parameters
            ):
                if hasattr(self, "ckpt_loaded_epoch"):
                    current_epoch = self.ckpt_loaded_epoch - 1
                else:
                    current_epoch = -1

            scheduler_args.update({"optimizer": optimizer})
            scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
            scheduler.last_epoch = current_epoch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": self.hparams.scheduler.lightning_args.interval,
                    "frequency": self.hparams.scheduler.lightning_args.frequency,
                    "name": "learning_rate",
                },
            }

        return {"optimizer": optimizer}

    def get_losses(self, batch):
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss, *_ = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
            durations=batch["durations"],
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.ckpt_loaded_epoch = checkpoint[
            "epoch"
        ]  # pylint: disable=attribute-defined-outside-init

    def training_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "sub_loss/train_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log(
            "sub_loss/val_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss

    def on_validation_end(self) -> None:
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))
            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(2):
                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
                    self.logger.experiment.add_image(
                        f"original/{i}",
                        plot_tensor(y.squeeze().cpu()),
                        self.current_epoch,
                        dataformats="HWC",
                    )

            log.debug("Synthesising...")
            for i in range(2):
                x = one_batch["x"][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
                spks = (
                    one_batch["spks"][i].unsqueeze(0).to(self.device)
                    if one_batch["spks"] is not None
                    else None
                )
                output = self.synthesise(
                    x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks
                )
                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                attn = output["attn"]
                self.logger.experiment.add_image(
                    f"generated_enc/{i}",
                    plot_tensor(y_enc.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"generated_dec/{i}",
                    plot_tensor(y_dec.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"alignment/{i}",
                    plot_tensor(attn.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )

    def on_before_optimizer_step(self, optimizer):
        self.log_dict(
            {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}
        )

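Editor's note, not part of this diff: the docstring above says the class exists so that everything outside the model definition can be reused. A minimal, hypothetical sketch of that contract, using made-up names and shapes and assuming the module were still importable (this commit deletes it as unused):

import torch

from matcha.models.baselightningmodule import BaseLightningClass  # deleted by this commit


class ToyMatchaLikeModel(BaseLightningClass):  # illustrative name only
    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, durations=None):
        # A real model returns (dur_loss, prior_loss, diff_loss, ...); zeros keep the sketch runnable.
        zero = torch.tensor(0.0, requires_grad=True)
        return zero, zero, zero

    def synthesise(self, x, x_lengths, n_timesteps=10, spks=None):
        # on_validation_end() above expects exactly these three keys for plotting.
        mel = torch.zeros(1, 80, 100)
        return {"encoder_outputs": mel, "decoder_outputs": mel, "attn": torch.zeros(1, x.shape[1], 100)}

With only those two model-specific pieces defined, get_losses(), the loss logging in training_step()/validation_step(), and the optimizer wiring would all come from the base class.
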
@@ -7,7 +7,6 @@ import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat

from matcha.models.components.transformer import BasicTransformerBlock

@@ -2,13 +2,8 @@ from abc import ABC

import torch
import torch.nn.functional as F

from matcha.models.components.decoder import Decoder

# from matcha.utils.pylogger import get_pylogger

# log = get_pylogger(__name__)


class BASECFM(torch.nn.Module, ABC):
    def __init__(

@@ -5,11 +5,7 @@ import math

import torch
import torch.nn as nn
from einops import rearrange

# import matcha.utils as utils
from matcha.utils.model import sequence_mask

# log = utils.get_pylogger(__name__)
from matcha.model import sequence_mask


class LayerNorm(nn.Module):

@@ -2,23 +2,17 @@ import datetime as dt
import math
import random

import matcha.monotonic_align as monotonic_align
import torch

import matcha.utils.monotonic_align as monotonic_align

# from matcha import utils
# from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
from matcha.model import (
    denormalize,
    duration_loss,
    fix_len_compatibility,
    generate_path,
    sequence_mask,
)

# log = utils.get_pylogger(__name__)
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder


class MatchaTTS(torch.nn.Module):  # 🍵

@@ -1,7 +1,7 @@
import numpy as np
import torch

from matcha.utils.monotonic_align.core import maximum_path_c
from matcha.monotonic_align.core import maximum_path_c


def maximum_path(value, mask):

@@ -1,13 +1,12 @@
#!/usr/bin/env python3
import datetime as dt
import logging

import onnxruntime as ort
import torch
from tokenizer import Tokenizer
import datetime as dt

import soundfile as sf
import torch
from inference import load_vocoder
from tokenizer import Tokenizer


class OnnxHifiGANModel:

@@ -14,15 +14,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from matcha.model import fix_len_compatibility
from matcha.models.matcha_tts import MatchaTTS
from matcha.tokenizer import Tokenizer
from matcha.utils.model import fix_len_compatibility
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from utils2 import MetricsTracker
from utils import MetricsTracker

from icefall.checkpoint import load_checkpoint, save_checkpoint
from icefall.dist import cleanup_dist, setup_dist

@@ -1,6 +0,0 @@
# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
# from matcha.utils.logging_utils import log_hyperparameters
# from matcha.utils.pylogger import get_pylogger
# from matcha.utils.rich_utils import enforce_tags, print_config_tree
# from matcha.utils.utils import extras, get_metric_value, task_wrapper
from matcha.utils.utils import intersperse

@@ -1,123 +0,0 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.

Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path

import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from tqdm.auto import tqdm

from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.utils.logging_utils import pylogger

log = pylogger.get_pylogger(__name__)


def compute_data_statistics(
    data_loader: torch.utils.data.DataLoader, out_channels: int
):
    """Generate data mean and standard deviation helpful in data normalisation

    Args:
        data_loader (torch.utils.data.Dataloader): _description_
        out_channels (int): mel spectrogram channels
    """
    total_mel_sum = 0
    total_mel_sq_sum = 0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"]
        mel_lengths = batch["y_lengths"]

        total_mel_len += torch.sum(mel_lengths)
        total_mel_sum += torch.sum(mels)
        total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

    data_mean = total_mel_sum / (total_mel_len * out_channels)
    data_std = torch.sqrt(
        (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)
    )

    return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default="256",
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(
            config_name=args.input_config, return_hydra_config=True, overrides=[]
        )

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        print(cfg)
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(
            os.path.join(root_path, cfg["train_filelist_path"])
        )
        cfg["valid_filelist_path"] = str(
            os.path.join(root_path, cfg["valid_filelist_path"])
        )
        cfg["load_durations"] = False

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)
    json.dump(
        params,
        open(output_file, "w"),
    )


if __name__ == "__main__":
    main()

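Editor's note, not part of this diff: compute_data_statistics above keeps only running sums rather than all mel frames. A small, hedged sanity check of the identity it relies on (std² = E[x²] − E[x]²), with made-up tensor shapes:

import torch

# Fake batch: (batch, n_feats, frames); every utterance pretended to be full length.
mels = torch.randn(3, 80, 120)
lengths = torch.tensor([120, 120, 120])
n_feats = 80

total_len = lengths.sum()
total_sum = mels.sum()
total_sq_sum = (mels**2).sum()

mean = total_sum / (total_len * n_feats)
std = torch.sqrt(total_sq_sum / (total_len * n_feats) - mean**2)

# Agrees with the direct (biased) statistics over the flattened data.
print(mean.item(), mels.mean().item())
print(std.item(), mels.std(unbiased=False).item())
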
@@ -1,215 +0,0 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.

Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path

import lightning
import numpy as np
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from torch import nn
from tqdm.auto import tqdm

from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations

log = pylogger.get_pylogger(__name__)


def save_durations_to_folder(
    attn: torch.Tensor,
    x_length: int,
    y_length: int,
    filepath: str,
    output_folder: Path,
    text: str,
):
    durations = attn.squeeze().sum(1)[:x_length].numpy()
    durations_json = get_phoneme_durations(durations, text)
    output = output_folder / Path(filepath).name.replace(".wav", ".npy")
    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
        json.dump(durations_json, f, indent=4, ensure_ascii=False)

    np.save(output, durations)


@torch.inference_mode()
def compute_durations(
    data_loader: torch.utils.data.DataLoader,
    model: nn.Module,
    device: torch.device,
    output_folder,
):
    """Generate durations from the model for each datapoint and save it in a folder

    Args:
        data_loader (torch.utils.data.DataLoader): Dataloader
        model (nn.Module): MatchaTTS model
        device (torch.device): GPU or CPU
    """

    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]
        x = x.to(device)
        y = y.to(device)
        x_lengths = x_lengths.to(device)
        y_lengths = y_lengths.to(device)
        spks = spks.to(device) if spks is not None else None

        _, _, _, attn = model(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
        )
        attn = attn.cpu()
        for i in range(attn.shape[0]):
            save_durations_to_folder(
                attn[i],
                x_lengths[i].item(),
                y_lengths[i].item(),
                batch["filepaths"][i],
                output_folder,
                batch["x_texts"][i],
            )


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="ljspeech.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default="32",
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    parser.add_argument(
        "-c",
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the checkpoint file to load the model from",
    )

    parser.add_argument(
        "-o",
        "--output-folder",
        type=str,
        default=None,
        help="Output folder to save the data statistics",
    )

    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Use CPU for inference, not recommended (default: use GPU if available)",
    )

    args = parser.parse_args()

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(
            config_name=args.input_config, return_hydra_config=True, overrides=[]
        )

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(
            os.path.join(root_path, cfg["train_filelist_path"])
        )
        cfg["valid_filelist_path"] = str(
            os.path.join(root_path, cfg["valid_filelist_path"])
        )
        cfg["load_durations"] = False

    if args.output_folder is not None:
        output_folder = Path(args.output_folder)
    else:
        output_folder = Path(cfg["train_filelist_path"]).parent / "durations"

    print(f"Output folder set to: {output_folder}")

    if os.path.exists(output_folder) and not args.force:
        print("Folder already exists. Use -f to force overwrite")
        sys.exit(1)

    output_folder.mkdir(parents=True, exist_ok=True)

    print(
        f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}"
    )
    print("Loading model...")
    device = get_device(args)
    model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    try:
        print("Computing stats for training set if exists...")
        train_dataloader = text_mel_datamodule.train_dataloader()
        compute_durations(train_dataloader, model, device, output_folder)
    except lightning.fabric.utilities.exceptions.MisconfigurationException:
        print("No training set found")

    try:
        print("Computing stats for validation set if exists...")
        val_dataloader = text_mel_datamodule.val_dataloader()
        compute_durations(val_dataloader, model, device, output_folder)
    except lightning.fabric.utilities.exceptions.MisconfigurationException:
        print("No validation set found")

    try:
        print("Computing stats for test set if exists...")
        test_dataloader = text_mel_datamodule.test_dataloader()
        compute_durations(test_dataloader, model, device, output_folder)
    except lightning.fabric.utilities.exceptions.MisconfigurationException:
        print("No test set found")

    print(f"[+] Done! Data statistics saved to: {output_folder}")


if __name__ == "__main__":
    # Helps with generating durations for the dataset to train other architectures
    # that cannot learn to align due to limited size of dataset
    # Example usage:
    # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
    # This will create a folder in data/processed_data/durations/ljspeech with the durations
    main()

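Editor's note, not part of this diff: in save_durations_to_folder above, per-token durations come from summing the hard alignment over the mel axis. A toy illustration of that step (the 0/1 matrix and its shape are assumptions made up for the example):

import torch

# Monotonic alignment of shape (text_len, mel_len): each column selects one token.
attn = torch.tensor(
    [
        [1, 1, 0, 0, 0],  # token 0 is attended for 2 frames
        [0, 0, 1, 0, 0],  # token 1 for 1 frame
        [0, 0, 0, 1, 1],  # token 2 for 2 frames
    ]
)
durations = attn.sum(1)  # same reduction as attn.squeeze().sum(1)[:x_length] above
print(durations.tolist())  # [2, 1, 2]
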
@@ -1,60 +0,0 @@
from typing import List

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(
                f"Instantiating callback <{cb_conf._target_}>"
            )  # pylint: disable=protected-access
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(
                f"Instantiating logger <{lg_conf._target_}>"
            )  # pylint: disable=protected-access
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger

@@ -1,57 +0,0 @@
from typing import Any, Dict

from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)

@@ -1,29 +0,0 @@
import logging

from lightning.pytorch.utilities import rank_zero_only


def get_pylogger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger.

    :param name: The name of the logger, defaults to ``__name__``.

    :return: A logger object.
    """
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger

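Editor's note, not part of this diff: get_pylogger above wraps every logging method in Lightning's rank_zero_only so that each message is emitted once rather than once per GPU process. A dependency-free sketch of the same idea (the RANK environment variable standing in for Lightning's rank bookkeeping is an assumption, not the deleted code):

import functools
import logging
import os


def rank_zero_only(fn):
    # Run the wrapped call only on the process whose RANK is 0 (or unset).
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example")
for level in ("debug", "info", "warning", "error", "critical"):
    setattr(logger, level, rank_zero_only(getattr(logger, level)))

logger.info("printed by rank 0 only, not duplicated per process")
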
@@ -1,103 +0,0 @@
from pathlib import Path
from typing import Sequence

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        _ = (
            queue.append(field)
            if field in cfg
            else log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )
        )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
            rich.print(tree, file=file)


@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = [t.strip() for t in tags.split(",") if t != ""]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

        if save_to_file:
            with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
                rich.print(cfg.tags, file=file)

@@ -1,261 +0,0 @@
import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

# from omegaconf import DictConfig

# from matcha.utils import pylogger, rich_utils

# log = pylogger.get_pylogger(__name__)


def extras(cfg: "DictConfig") -> None:
    """Applies optional utilities before the task is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise ex

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap


def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: The name of the metric to retrieve.
    :return: The value of the metric.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise ValueError(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def save_figure_to_numpy(fig):
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_tensor(tensor):
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def save_plot(tensor, savepath):
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    plt.savefig(savepath)
    plt.close()


def to_numpy(tensor):
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, list):
        return np.array(tensor)
    else:
        raise TypeError("Unsupported type for conversion to numpy array")


def get_user_data_dir(appname="matcha_tts"):
    """
    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory
    """

    MATCHA_HOME = os.environ.get("MATCHA_HOME")
    if MATCHA_HOME is not None:
        ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        key = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        )
        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        ans = Path("~/Library/Application Support/").expanduser()
    else:
        ans = Path.home().joinpath(".local/share")

    final_path = ans.joinpath(appname)
    final_path.mkdir(parents=True, exist_ok=True)
    return final_path


def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    import gdown
    import wget

    if Path(checkpoint_path).exists():
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        print(f"[+] Model already present at {checkpoint_path}!")
        return
    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    print(f"[-] Model not found at {checkpoint_path}! Will download it")
    checkpoint_path = str(checkpoint_path)
    if not use_wget:
        gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
    else:
        wget.download(url=url, out=checkpoint_path)


def get_phoneme_durations(durations, phones):
    prev = durations[0]
    merged_durations = []
    # Convolve with stride 2
    for i in range(1, len(durations), 2):
        if i == len(durations) - 2:
            # if it is last take full value
            next_half = durations[i + 1]
        else:
            next_half = ceil(durations[i + 1] / 2)

        curr = prev + durations[i] + next_half
        prev = durations[i + 1] - next_half
        merged_durations.append(curr)

    assert len(phones) == len(merged_durations)
    assert len(merged_durations) == (len(durations) - 1) // 2

    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
    start = torch.tensor(0)
    duration_json = []
    for i, duration in enumerate(merged_durations):
        duration_json.append(
            {
                phones[i]: {
                    "starttime": start.item(),
                    "endtime": duration.item(),
                    "duration": duration.item() - start.item(),
                }
            }
        )
        start = duration

    assert list(duration_json[-1].values())[0]["endtime"] == sum(
        durations
    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
    return duration_json

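Editor's note, not part of this diff: two of the helpers deleted above are worth a quick illustration. intersperse() inserts a blank symbol between and around tokens, which is easy to verify on a toy list (same logic as above, made-up values):

def intersperse(lst, item):
    # Adds blank symbol: [a, b, c] -> [blank, a, blank, b, blank, c, blank]
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


assert intersperse([7, 8, 9], 0) == [0, 7, 0, 8, 0, 9, 0]

get_phoneme_durations() then merges those interleaved blank/phoneme durations back onto the original phonemes ("Convolve with stride 2"): each blank's frames are split between its neighbouring phonemes (the first and last blanks go entirely to the first and last phonemes), so the resulting start/end times still cover every frame exactly once, which is what the final assert checks.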