# Copyright 2022-2024 Xiaomi Corp.        (authors: Daniel Povey
#                                                   Zengwei Yao
#                                                   Mingshuang Luo,
#                                                   Zengrui Jin,)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import random
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class TensorDiagnosticOptions(object):
    """Options object for tensor diagnostics:

    Args:
      max_eig_dim:
        The maximum dimension for which we print out eigenvalues
        (limited for speed reasons).
    """

    def __init__(self, max_eig_dim: int = 512):
        self.max_eig_dim = max_eig_dim

    def dim_is_summarized(self, size: int):
        return size > 10 and size != 31


def get_tensor_stats(
    x: Tensor,
    dim: int,
    stats_type: str,
) -> Tuple[Tensor, int]:
    """
    Returns the specified transformation of the Tensor (either x, x.abs()
    or (x > 0)), summed over all dims but `dim`.

    E.g. for x of shape (10, 50), dim=1 and stats_type="abs", the returned
    stats has shape (50,) and count is 10.

    Args:
      x:
        Tensor, tensor to be analyzed
      dim:
        Dimension with 0 <= dim < x.ndim
      stats_type:
        The stats_type includes several types:
        "abs" -> take abs() before summing
        "positive" -> take (x > 0) before summing
        "rms" -> square before summing, we'll take sqrt later
        "value" -> just sum x itself
        "max", "min" -> take the maximum or minimum [over all other dims but dim]
            instead of summing
        "rms-sort" -> this is a bit different from the others; it's based on
            computing the rms over the specified dim and returning percentiles
            of the result (11 of them).
        "eigs" -> return the uncentered covariance x^T x (summed over the other
            dims); its eigenvalues are computed later, in print_diagnostics().
    Returns:
      stats: a Tensor of shape (x.shape[dim],), except for "rms-sort" (which
        returns 11 percentile values) and "eigs" (which returns an (s, s)
        matrix, where s == x.shape[dim]).
      count: an integer saying how many items were counted in each element
        of stats.
    """

    if stats_type == "rms-sort":
        rms = (x**2).mean(dim=dim).sqrt()
        rms = rms.flatten()
        rms = rms.sort()[0]
        rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
        count = 1
        return rms, count

    count = x.numel() // x.shape[dim]

    if stats_type == "eigs":
        x = x.transpose(dim, -1)
        x = x.reshape(-1, x.shape[-1])
        # shape of returned tensor: (s, s),
        # where s is size of dimension `dim` of original x.
        return torch.matmul(x.transpose(0, 1), x), count
    elif stats_type == "abs":
        x = x.abs()
    elif stats_type == "rms":
        x = x**2
    elif stats_type == "positive":
        x = (x > 0).to(dtype=torch.float)
    else:
        assert stats_type in ["value", "max", "min"]

    sum_dims = [d for d in range(x.ndim) if d != dim]
    if len(sum_dims) > 0:
        if stats_type == "max":
            for dim in reversed(sum_dims):
                x = torch.max(x, dim=dim)[0]
        elif stats_type == "min":
            for dim in reversed(sum_dims):
                x = torch.min(x, dim=dim)[0]
        else:
            x = torch.sum(x, dim=sum_dims)
    x = x.flatten().clone()
    return x, count


@dataclass
class TensorAndCount:
    tensor: Tensor
    count: int


class TensorDiagnostic(object):
    """This class is not directly used by the user; it is responsible for
    collecting diagnostics for a module or parameter tensor of a torch.nn.Module.

    Args:
      opts:
        Options object.
      name:
        The name associated with this diagnostics object, will probably be
        {module_name}.X where X is "output" or "grad", or {parameter_name}.Y
        where Y is param_value or param_grad.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate()

        self.stats = None  # we'll later assign a list to self.stats.
        # It's a list of dicts, indexed by dim (i.e. by the
        # axis of the tensor).  The dicts, in turn, are
        # indexed by `stats-type` which are strings in
        # ["abs", "max", "min", "positive", "value", "rms"].

        # scalar_stats contains some analysis of the activations and gradients.
        self.scalar_stats = None

        # The keys into self.stats[dim] are strings, which can be
        # "abs", "max", "min", "value", "positive", "rms" or "eigs".
        # The values, e.g. self.stats[dim]["rms"], are lists of the dataclass
        # TensorAndCount, each containing a tensor and its associated count
        # (the product of the sizes of the other dims that we aggregated over,
        # e.g. the number of frames and/or batch elements and/or channels).
        # ... we actually accumulate into an existing tensor / count whenever the
        # new stats tensor has the same shape, only appending a new element to
        # the list if this dim had a different size.
        # For the "eigs" stats type, if we detect a size mismatch for this dim
        # we put None as the value, which disables it.

    def accumulate(self, x, class_name: Optional[str] = None):
        """
        Accumulate tensors.
        """
        if class_name is not None:
            self.class_name = class_name
        if isinstance(x, Tuple):
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.numel() == 0:
            # for empty tensor
            return
        x = x.detach().clone()
        if x.ndim == 0:
            x = x.unsqueeze(0)
        ndim = x.ndim
        if self.stats is None:
            self.stats = [dict() for _ in range(ndim)]

        for dim in range(ndim):
            this_dim_stats = self.stats[dim]
            if ndim > 1:
                # rms-sort is different from the others: it's based on computing the
                # rms over just this dim, then sorting and returning the percentiles.
                stats_types = [
                    "abs",
                    "max",
                    "min",
                    "positive",
                    "value",
                    "rms",
                    "rms-sort",
                ]
                if x.shape[dim] <= self.opts.max_eig_dim:
                    stats_types.append("eigs")
            else:
                stats_types = ["value", "abs", "max", "min"]

            for stats_type in stats_types:
                stats, count = get_tensor_stats(x, dim, stats_type)
                if stats_type not in this_dim_stats:
                    this_dim_stats[stats_type] = []  # list of TensorAndCount

                done = False
                if this_dim_stats[stats_type] is None:
                    # we can reach here if we detected for stats_type "eigs" that
                    # there was more than one different size for this dim.  Then we
                    # disable accumulating this stats type, as it uses too much memory.
                    continue
                for s in this_dim_stats[stats_type]:
                    if s.tensor.shape == stats.shape:
                        if stats_type == "max":
                            s.tensor = torch.maximum(s.tensor, stats)
                        elif stats_type == "min":
                            s.tensor = torch.minimum(s.tensor, stats)
                        else:
                            assert stats_type != "max"
                            s.tensor += stats
                        s.count += count
                        done = True
                        break
                if not done:
                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                        # >1 size encountered on this dim, e.g. it's a batch or time
                        # dimension; don't accumulate the "eigs" stats type, it uses
                        # too much memory.
                        this_dim_stats[stats_type] = None
                    else:
                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))

    def print_diagnostics(self):
        """Print diagnostics for each dimension of the tensor."""
        if self.stats is None:
            print(f"Warning: the stats of {self.name} is None.")
            return
        for dim, this_dim_stats in enumerate(self.stats):
            if "rms" in this_dim_stats and "value" in this_dim_stats:
                # produce "stddev" stats, which is centered RMS.
                rms_stats_list = this_dim_stats["rms"]
                value_stats_list = this_dim_stats["value"]
                if len(rms_stats_list) == len(value_stats_list):
                    stddev_stats_list = []
                    for r, v in zip(rms_stats_list, value_stats_list):
                        stddev_stats_list.append(
                            # r.count and v.count should be the same, but we don't check this.
                            TensorAndCount(
                                r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
                                r.count,
                            )
                        )
                    this_dim_stats["stddev"] = stddev_stats_list

            for stats_type, stats_list in this_dim_stats.items():
                # stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max".
                # stats_list could be a list of TensorAndCount (one entry per distinct tensor
                # shape of the stats), or None.
                if stats_list is None:
                    assert stats_type == "eigs"
                    continue

                def get_count(count):
                    return 1 if stats_type in ["max", "min"] else count

                if len(stats_list) == 1:
                    stats = stats_list[0].tensor / get_count(stats_list[0].count)
                else:
                    # a dimension that has variable size in different nnet
                    # forwards, e.g. a time dimension in an ASR model.
                    stats = torch.cat(
                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                    )

                if stats_type == "eigs":
                    try:
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                            eigs, _ = torch.linalg.eigh(stats)
                        else:
                            eigs, _ = torch.symeig(stats)
                        stats = eigs.abs().sqrt()
                    except:  # noqa
                        print("Error getting eigenvalues, trying another method.")
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                            eigs, _ = torch.linalg.eig(stats)
                            eigs = eigs.abs()
                        else:
                            eigs, _ = torch.eig(stats)
                            eigs = eigs.norm(dim=1)
                        stats = eigs.sqrt()
                    # sqrt so it reflects data magnitude, like stddev, not variance

                if stats_type in ["rms", "stddev"]:
                    # we stored the square; after aggregation we need to take the sqrt.
                    stats = stats.sqrt()

                # if `summarize` is true, we print percentiles of the stats; else,
                # we print out the individual elements.
                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
                    stats.numel()
                )
                if summarize:
                    # usually `summarize` will be true
                    # print out percentiles.
                    stats = stats.sort()[0]
                    num_percentiles = 10
                    size = stats.numel()
                    percentiles = []
                    for i in range(num_percentiles + 1):
                        index = (i * (size - 1)) // num_percentiles
                        percentiles.append(stats[index].item())
                    percentiles = ["%.2g" % x for x in percentiles]
                    percentiles = " ".join(percentiles)
                    ans = f"percentiles: [{percentiles}]"
                else:
                    ans = stats.tolist()
                    ans = ["%.2g" % x for x in ans]
                    ans = "[" + " ".join(ans) + "]"

                if stats_type in ["value", "rms", "stddev", "eigs"]:
                    # This norm is useful because it is strictly less than the largest
                    # sqrt(eigenvalue) of the variance, which we print out, and shows,
                    # speaking in an approximate way, how much of that largest eigenvalue
                    # can be attributed to the mean of the distribution.
                    norm = (stats**2).sum().sqrt().item()
                    ans += f", norm={norm:.2g}"

                mean = stats.mean().item()
                rms = (stats**2).mean().sqrt().item()
                ans += f", mean={mean:.3g}, rms={rms:.3g}"

                # OK, "ans" contains the actual stats, e.g.
                # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
                sizes = [x.tensor.shape[0] for x in stats_list]
                size_str = (
                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
                )
                maybe_class_name = (
                    f" type={self.class_name}," if self.class_name is not None else ""
                )
                print(
                    f"module={self.name},{maybe_class_name} dim={dim}, "
                    f"size={size_str}, {stats_type} {ans}"
                )


class ScalarDiagnostic(object):
    """This class is not directly used by the user; it is responsible for
    collecting diagnostics for a single module (a subclass of torch.nn.Module)
    that represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate_input()
        self.is_forward_pass = True

        self.tick_scale = None
        self.saved_inputs = []
        self.is_ok = True

        self.counts = None
        self.sum_grad = None
        self.sum_gradsq = None
        self.sum_abs_grad = None

    def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
        """
        Called in the forward pass.
        """
        if not self.is_forward_pass:
            # in case we did a forward pass without a backward pass, for some reason.
            self.saved_inputs = []
            self.is_forward_pass = True

        if class_name is not None:
            self.class_name = class_name

        if not self.is_ok:
            return

        limit = 10
        if len(self.saved_inputs) > limit:
            print(
                f"ERROR: forward pass called for this module over {limit} times "
                f"with no backward pass.  Will not accumulate scalar stats."
            )
            self.is_ok = False
            return
        self.saved_inputs.append(x)

    def accumulate_output_grad(self, grad: Tensor):
        if not self.is_ok:
            return
        if self.is_forward_pass:
            self.is_forward_pass = False

        last_shape = (
            "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
        )
        if len(self.saved_inputs) == 0 or grad.shape != last_shape:
            print(
                f"ERROR: shape mismatch or no forward activation present when backward "
                f"pass called: grad shape={tuple(grad.shape)}, "
                f"num-saved-inputs={len(self.saved_inputs)}, "
                f"shape-of-last-saved-input={last_shape}"
            )
            self.is_ok = False
            return
        x = self.saved_inputs.pop()
        self.process_input_and_grad(x, grad)

    def process_input_and_grad(self, x: Tensor, grad: Tensor):
        assert x.shape == grad.shape
        x = x.flatten()
        grad = grad.flatten()

        num_ticks_per_side = 256

        if self.tick_scale is None:
            x_abs_sorted = x.abs().sort()[0]
            # take the 98th percentile as the largest value we count separately.
            index = int(x.numel() * 0.98)
            self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)

            # integerize x into ticks in [-num_ticks_per_side .. num_ticks_per_side - 1]
            self.counts = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.long, device=x.device
            )
            self.sum_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            # sum_gradsq is for getting error bars.
            self.sum_gradsq = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            self.sum_abs_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )

        # this will round down.
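        # A worked example of the bucketing below (illustrative numbers): with
        # num_ticks_per_side = 256, a value equal to the 98th-percentile magnitude
        # has x / tick_scale ~= 256, which is clamped to 255 and shifted to bucket
        # 511; its negation is clamped to -256 and shifted to bucket 0.  (Strictly,
        # .to(torch.long) truncates toward zero, so negative values are not exactly
        # "rounded down".)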
        x = (x / self.tick_scale).to(torch.long)
        x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
        x = x + num_ticks_per_side

        self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
        self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
        self.sum_gradsq.index_add_(
            dim=0, index=x, source=(grad * grad).to(torch.double)
        )
        self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))

    def print_diagnostics(self):
        """Print diagnostics."""
        if self.is_ok is False or self.counts is None:
            print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
            return

        counts = self.counts.to("cpu")
        sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
        sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
        sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)

        counts_cumsum = counts.cumsum(dim=0)
        counts_tot = counts_cumsum[-1]

        # Subdivide the distribution into `num_bins` intervals for analysis, for
        # greater statistical significance.  Each bin covers several of the
        # original 'tick' intervals.
        num_bins = 20
        # integer division
        counts_per_bin = (counts_tot // num_bins) + 1
        bin_indexes = counts_cumsum // counts_per_bin
        bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)

        bin_counts = torch.zeros(num_bins, dtype=torch.long)
        bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
        bin_grad = torch.zeros(num_bins)
        bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
        bin_gradsq = torch.zeros(num_bins)
        bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
        bin_abs_grad = torch.zeros(num_bins)
        bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

        avg_grad = bin_grad / bin_counts
        avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()

        bin_boundary_counts = (
            torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
        )
        bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
        # boundaries are the "x" values between the bins, e.g. corresponding to the
        # locations of percentiles of the distribution.
        num_ticks_per_side = counts.numel() // 2
        bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale

        bin_grad = bin_grad / (bin_counts + 1)
        bin_conf_interval = bin_gradsq.sqrt() / (
            bin_counts + 1
        )  # consider this a standard deviation.
        # bin_grad / bin_abs_grad will give us a sense of how important, in a
        # practical sense, the gradients are.
        bin_abs_grad = bin_abs_grad / (bin_counts + 1)
        bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
        bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)

        def tensor_to_str(x: Tensor):
            x = ["%.2g" % f for f in x]
            x = "[" + " ".join(x) + "]"
            return x

        maybe_class_name = (
            f" type={self.class_name}," if self.class_name is not None else ""
        )

        print(
            f"module={self.name},{maybe_class_name} "
            f"bin-boundaries={tensor_to_str(bin_boundaries)}, "
            f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}"
        )


class ModelDiagnostic(object):
    """This class stores diagnostics for all tensors in the torch.nn.Module.

    Args:
      opts:
        Options object.
    """

    def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
        # In this dictionary, the keys are tensor names and the values
        # are the corresponding TensorDiagnostic objects.
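        # The names are generated by attach_diagnostics(), e.g. "{module_name}.output",
        # "{module_name}.grad", "{parameter_name}.param_value" and
        # "{parameter_name}.param_grad"; names ending in ".scalar" map to
        # ScalarDiagnostic objects instead (see __getitem__ below).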
        if opts is None:
            self.opts = TensorDiagnosticOptions()
        else:
            self.opts = opts
        self.diagnostics = dict()

    def __getitem__(self, name: str):
        T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
        if name not in self.diagnostics:
            self.diagnostics[name] = T(self.opts, name)
        return self.diagnostics[name]

    def print_diagnostics(self):
        """Print diagnostics for each tensor."""
        for k in sorted(self.diagnostics.keys()):
            self.diagnostics[k].print_diagnostics()


def get_class_name(module: nn.Module):
    ans = type(module).__name__
    # we put the below in try blocks in case anyone is using a different version
    # of these modules that might have different member names.
    if ans == "Balancer" or ans == "ActivationBalancer":
        try:
            ans += (
                f"[{float(module.min_positive)},{float(module.max_positive)},"
                f"{float(module.min_abs)},{float(module.max_abs)}]"
            )
        except:  # noqa
            pass
    elif ans == "AbsValuePenalizer":
        try:
            ans += f"[{module.limit}]"
        except:  # noqa
            pass
    return ans


def attach_diagnostics(
    model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
    """Attach a ModelDiagnostic object to the model by
    1) registering a forward hook and a backward hook on each module, to accumulate
       its output tensors and gradient tensors, respectively;
    2) registering a backward hook on each module parameter, to accumulate its
       values and gradients.

    Args:
      model:
        the model to be analyzed.
      opts:
        Options object.

    Returns:
      The ModelDiagnostic object attached to the model.
    """
    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # Setting _model_diagnostic=ans and _name=name below, instead of trying to
        # capture the variables, ensures that we use the current values.
        # (this matters for `name`, since the variable gets overwritten).
        # These closures don't really capture by value, only by
        # "the final value the variable got in the function" :-(
        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]
            if isinstance(_output, Tensor) and _output.dtype in (
                torch.float32,
                torch.float16,
                torch.float64,
            ):
                _model_diagnostic[f"{_name}.output"].accumulate(
                    _output, class_name=get_class_name(_module)
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, Tensor) and o.dtype in (
                        torch.float32,
                        torch.float16,
                        torch.float64,
                    ):
                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
                            o, class_name=get_class_name(_module)
                        )

        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]
            if isinstance(_output, Tensor) and _output.dtype in (
                torch.float32,
                torch.float16,
                torch.float64,
            ):
                _model_diagnostic[f"{_name}.grad"].accumulate(
                    _output, class_name=get_class_name(_module)
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, Tensor) and o.dtype in (
                        torch.float32,
                        torch.float16,
                        torch.float64,
                    ):
                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
                            o, class_name=get_class_name(_module)
                        )

        module.register_forward_hook(forward_hook)
        if hasattr(module, "register_full_backward_hook"):
            module.register_full_backward_hook(backward_hook)
        else:
            module.register_backward_hook(backward_hook)

        if type(module).__name__ in [
            "Sigmoid",
            "Tanh",
            "ReLU",
            "TanSwish",
            "Swish",
            "DoubleSwish",
            "Swoosh",
        ]:
            # For these specific module types, accumulate some additional diagnostics
            # that can help us improve the activation function.  These require a lot
            # of memory, to save the forward activations, so limit this to some
            # select classes.
            # Note: this will not work correctly for all model types.
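            # Roughly what ScalarDiagnostic does with these two hooks: it buckets
            # each scalar input value of the activation and accumulates the gradient
            # w.r.t. the activation's output per bucket, so that print_diagnostics()
            # can show how the gradient behaves as a function of the input value.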
            def scalar_forward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_input, tuple):
                    (_input,) = _input
                assert isinstance(_input, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_input(
                    _input, class_name=get_class_name(_module)
                )

            def scalar_backward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_output, tuple):
                    (_output,) = _output
                assert isinstance(_output, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

            module.register_forward_hook(scalar_forward_hook)
            if hasattr(module, "register_full_backward_hook"):
                module.register_full_backward_hook(scalar_backward_hook)
            else:
                module.register_backward_hook(scalar_backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(
            grad, _parameter=parameter, _model_diagnostic=ans, _name=name
        ):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        try:
            parameter.register_hook(param_backward_hook)
        except:  # noqa
            logging.warning(
                f"Warning: could not register backward hook for parameter {name}, "
                f"it might not be differentiable."
            )

    return ans


def _test_tensor_diagnostic():
    opts = TensorDiagnosticOptions(512)

    diagnostic = TensorDiagnostic(opts, "foo")

    for _ in range(10):
        diagnostic.accumulate(torch.randn(50, 100) * 10.0)

    diagnostic.print_diagnostics()

    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))

    diagnostic = attach_diagnostics(model, opts)
    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100)
        y = model(x)
        y.sum().backward()

    diagnostic.print_diagnostics()


if __name__ == "__main__":
    _test_tensor_diagnostic()
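
# A minimal usage sketch (the training loop, `model`, `compute_loss` and `batches`
# below are hypothetical stand-ins for your own training code, not part of this
# module): attach the hooks before training, run a few forward/backward passes,
# then print the accumulated stats:
#
#     diagnostic = attach_diagnostics(model, TensorDiagnosticOptions(max_eig_dim=512))
#     for batch in batches:
#         loss = compute_loss(model, batch)
#         loss.backward()
#     diagnostic.print_diagnostics()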