"""Miscellaneous helpers: run wrappers, plotting, array conversion, data dirs, downloads."""

import functools
import logging
import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import numpy as np
import torch

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Only the plotting helpers need matplotlib. Keeping the rest of the module
    # importable without it mirrors how gdown/wget are imported on demand below.
    plt = None

if TYPE_CHECKING:
    # Needed for annotations only; never imported at runtime.
    from omegaconf import DictConfig

# Standard-library logger. Every function below that logs uses this name, so it
# must exist at module level.
log = logging.getLogger(__name__)


def extras(cfg: "DictConfig") -> None:
    """Applies optional utilities before the task is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! ")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! ")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! ")
        # Imported on demand: only these two branches need the Rich helpers.
        from matcha.utils import rich_utils  # pylint: disable=import-outside-toplevel

        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! ")
        from matcha.utils import rich_utils  # pylint: disable=import-outside-toplevel

        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception
          (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder
          (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped. It is called as
        ``task_func(cfg=cfg)`` and must return a ``(metric_dict, object_dict)`` pair.
    :return: The wrapped task function.
    """

    @functools.wraps(task_func)
    def wrap(cfg: "DictConfig") -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb  # pylint: disable=import-outside-toplevel

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap


def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values; each value must expose ``.item()``.
    :param metric_name: The name of the metric to retrieve. If falsy, nothing is looked up.
    :return: The value of the metric, or ``None`` when ``metric_name`` is falsy.
    :raises ValueError: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise ValueError(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value


def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    :param lst: The sequence to pad (e.g. a list of symbol ids).
    :param item: The filler (blank symbol) to insert.
    :return: A list of length ``2 * len(lst) + 1``.
    """
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def _require_pyplot():
    """Return ``matplotlib.pyplot`` or raise a clear error if matplotlib is missing."""
    if plt is None:
        raise ImportError("matplotlib is required for the plotting utilities")
    return plt


def save_figure_to_numpy(fig):
    """Copy the rendered pixels of ``fig`` into an RGB ``uint8`` array.

    The canvas is read through ``buffer_rgba()`` and the alpha channel is dropped.

    :param fig: A figure whose canvas provides ``buffer_rgba()``; the caller is
        expected to have drawn it already (``fig.canvas.draw()``).
    :return: An independent copy of the colour channels of the canvas buffer.
    """
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # `.copy()` detaches the result from the canvas buffer, which is reused on redraw.
    return rgba[..., :3].copy()


def _draw_tensor_figure(tensor):
    """Create a 12x3 inch figure showing ``tensor`` as an image with a colorbar."""
    pyplot = _require_pyplot()
    pyplot.style.use("default")
    fig, ax = pyplot.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        pyplot.close(fig)
        raise
    return fig


def plot_tensor(tensor):
    """Render ``tensor`` as an image plot and return the pixels.

    :param tensor: 2-D array-like accepted by ``Axes.imshow``.
    :return: The RGB pixel array produced by :func:`save_figure_to_numpy`.
    """
    fig = _draw_tensor_figure(tensor)
    try:
        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)


def save_plot(tensor, savepath):
    """Render ``tensor`` as an image plot and write it to ``savepath``.

    :param tensor: 2-D array-like accepted by ``Axes.imshow``.
    :param savepath: Destination passed to ``Figure.savefig``.
    """
    fig = _draw_tensor_figure(tensor)
    try:
        fig.savefig(savepath)
    finally:
        plt.close(fig)


def to_numpy(tensor):
    """Convert ``tensor`` to a NumPy array.

    :param tensor: A ``numpy.ndarray`` (returned unchanged), a ``torch.Tensor``
        (detached and moved to CPU first) or a ``list``.
    :return: The data as a ``numpy.ndarray``.
    :raises TypeError: For any other input type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, list):
        return np.array(tensor)
    raise TypeError("Unsupported type for conversion to numpy array")


def get_user_data_dir(appname="matcha_tts"):
    """Return (and create if needed) the per-user data directory for ``appname``.

    The base directory is, in order of precedence: the ``MATCHA_HOME`` environment
    variable, the "Local AppData" shell folder on Windows,
    ``~/Library/Application Support`` on macOS, and ``~/.local/share`` elsewhere.

    :param appname: Name of the application; used as the final path component.
    :return: ``Path`` to the (now existing) user data directory.
    """
    matcha_home = os.environ.get("MATCHA_HOME")
    if matcha_home is not None:
        base = Path(matcha_home).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        # The context manager closes the registry handle once the value is read.
        with winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        ) as key:
            dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        base = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        base = Path("~/Library/Application Support/").expanduser()
    else:
        base = Path.home().joinpath(".local/share")

    final_path = base.joinpath(appname)
    final_path.mkdir(parents=True, exist_ok=True)
    return final_path


def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Download the checkpoint from ``url`` unless ``checkpoint_path`` already exists.

    :param checkpoint_path: Where the checkpoint is expected / will be stored.
    :param url: Download location.
    :param use_wget: Download with ``wget`` when true, otherwise with ``gdown``.
    """
    if Path(checkpoint_path).exists():
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        print(f"[+] Model already present at {checkpoint_path}!")
        return

    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    print(f"[-] Model not found at {checkpoint_path}! Will download it")

    # Make sure the destination directory exists before the downloader writes to it.
    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
    checkpoint_path = str(checkpoint_path)

    # Import only the downloader that is actually used, and only when a download
    # is needed, so the other package does not have to be installed.
    if use_wget:
        import wget  # pylint: disable=import-outside-toplevel

        wget.download(url=url, out=checkpoint_path)
    else:
        import gdown  # pylint: disable=import-outside-toplevel

        gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)


def get_phoneme_durations(durations, phones):
    """Merge durations of an interspersed sequence into per-phone time spans.

    ``durations`` is read as ``[d0, p1, d1, p2, d2, ..., pn, dn]``: odd positions
    belong to ``phones`` and even positions to the fillers around them (the layout
    produced by :func:`intersperse`). Each phone receives its own duration plus
    what is left of the preceding filler and the rounded-up half of the following
    one; the last phone receives the whole trailing filler.

    :param durations: Odd-length sequence of durations (list or tensor).
    :param phones: The ``(len(durations) - 1) // 2`` phone labels.
    :return: A list with one ``{phone: {"starttime", "endtime", "duration"}}`` dict
        per phone, in order; empty when there are no phones.
    :raises ValueError: If ``durations`` is empty or of even length, or if the
        number of phones does not match.
    """
    values = durations.tolist() if hasattr(durations, "tolist") else list(durations)

    if len(values) % 2 == 0:
        raise ValueError(f"`durations` must have odd, non-zero length, got {len(values)}")
    if len(phones) != len(values) // 2:
        raise ValueError(f"Expected {len(values) // 2} phones for {len(values)} durations, got {len(phones)}")
    if len(values) == 1:
        return []

    merged_durations = []
    carry = values[0]  # share of the preceding filler owned by the next phone
    last_phone_idx = len(values) - 2

    # Convolve with stride 2
    for i in range(1, len(values), 2):
        filler = values[i + 1]
        # the last phone takes the full trailing filler, others the rounded-up half
        next_half = filler if i == last_phone_idx else ceil(filler / 2)
        merged_durations.append(carry + values[i] + next_half)
        carry = filler - next_half

    ends = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long).tolist()
    starts = [0] + ends[:-1]

    duration_json = [
        {phone: {"starttime": start, "endtime": end, "duration": end - start}}
        for phone, start, end in zip(phones, starts, ends)
    ]

    # Internal sanity check: merging must not gain or lose any duration.
    total = sum(values)
    assert ends[-1] == total, f"{ends[-1], total}"
    return duration_json