import argparse
import collections
import json
import logging
import os
import pathlib
import random
import re
import subprocess
from collections import defaultdict

# from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

# from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

Pathlike = Union[str, Path]


def get_world_size():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1


def get_rank():
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0


def get_local_rank():
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
    # torch.distributed has no get_local_rank(); outside a launcher that
    # sets LOCAL_RANK (e.g., torchrun), fall back to 0.
    return 0


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")

    def __str__(self, indent: int = 2):
        tmp = {}
        for k, v in self.items():
            # PosixPath and torch.device are not JSON serializable
            if isinstance(v, (pathlib.Path, torch.device)):
                v = str(v)
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)


def setup_logger(
    log_filename: Pathlike,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Set up logging: messages go to a file and, optionally, to the console.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical".
      use_console:
        True to also print logs to console.
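
    Example:
      A usage sketch; the log path below is hypothetical::

        setup_logger("exp/log/log-train", log_level="debug")
        logging.info("Training started")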
""" now = datetime.now() date_time = now.strftime("%Y-%m-%d-%H-%M-%S") if dist.is_available() and dist.is_initialized(): world_size = dist.get_world_size() rank = dist.get_rank() formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa log_filename = f"{log_filename}-{date_time}-{rank}" else: formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" os.makedirs(os.path.dirname(log_filename), exist_ok=True) level = logging.ERROR if log_level == "debug": level = logging.DEBUG elif log_level == "info": level = logging.INFO elif log_level == "warning": level = logging.WARNING elif log_level == "critical": level = logging.CRITICAL logging.basicConfig( filename=log_filename, format=formatter, level=level, filemode="w", force=True, ) if use_console: console = logging.StreamHandler() console.setLevel(level) console.setFormatter(logging.Formatter(formatter)) logging.getLogger("").addHandler(console) class MetricsTracker(collections.defaultdict): def __init__(self): # Passing the type 'int' to the base-class constructor # makes undefined items default to int() which is zero. # This class will play a role as metrics tracker. # It can record many metrics, including but not limited to loss. super(MetricsTracker, self).__init__(int) def __add__(self, other: "MetricsTracker") -> "MetricsTracker": ans = MetricsTracker() for k, v in self.items(): ans[k] = v for k, v in other.items(): if v - v == 0: ans[k] = ans[k] + v return ans def __mul__(self, alpha: float) -> "MetricsTracker": ans = MetricsTracker() for k, v in self.items(): ans[k] = v * alpha return ans def __str__(self) -> str: ans_frames = "" ans_utterances = "" for k, v in self.norm_items(): norm_value = "%.4g" % v if "utt_" not in k: ans_frames += str(k) + "=" + str(norm_value) + ", " else: ans_utterances += str(k) + "=" + str(norm_value) if k == "utt_duration": ans_utterances += " frames, " elif k == "utt_pad_proportion": ans_utterances += ", " else: raise ValueError(f"Unexpected key: {k}") frames = "%.2f" % self["frames"] ans_frames += "over " + str(frames) + " frames. " if ans_utterances != "": utterances = "%.2f" % self["utterances"] ans_utterances += "over " + str(utterances) + " utterances." return ans_frames + ans_utterances def norm_items(self) -> List[Tuple[str, float]]: """ Returns a list of pairs, like: [('ctc_loss', 0.1), ('att_loss', 0.07)] """ num_frames = self["frames"] if "frames" in self else 1 num_utterances = self["utterances"] if "utterances" in self else 1 ans = [] for k, v in self.items(): if k == "frames" or k == "utterances": continue norm_value = ( float(v) / num_frames if "utt_" not in k else float(v) / num_utterances ) ans.append((k, norm_value)) return ans def reduce(self, device): """ Reduce using torch.distributed, which I believe ensures that all processes get the total. """ keys = sorted(self.keys()) s = torch.tensor([float(self[k]) for k in keys], device=device) dist.all_reduce(s, op=dist.ReduceOp.SUM) for k, v in zip(keys, s.cpu().tolist()): self[k] = v def write_summary( self, tb_writer: SummaryWriter, prefix: str, batch_idx: int, ) -> None: """Add logging information to a TensorBoard writer. Args: tb_writer: a TensorBoard writer prefix: a prefix for the name of the loss, e.g. "train/valid_", or "train/current_" batch_idx: The current batch index, used as the x-axis of the plot. """ for k, v in self.norm_items(): tb_writer.add_scalar(prefix + k, v, batch_idx)