From c2c46ea02311ac30bd366efafcc70c1d1f7e0efc Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 18 May 2022 21:42:57 +0800
Subject: [PATCH] Update diagnostics, hopefully print more stats.

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless4b/train.py
---
 icefall/diagnostics.py | 290 +++++++++++++++++------------------------
 1 file changed, 120 insertions(+), 170 deletions(-)

diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index bc8fe3069..d0a0461d0 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from typing import List, Optional, Tuple
-
+from dataclasses import dataclass
 import torch
 from torch import Tensor, nn
 
@@ -94,138 +94,12 @@ def get_tensor_stats(
     return x, count
 
 
-def get_diagnostics_for_dim(
-    dim: int,
-    tensors: List[Tensor],
-    options: TensorDiagnosticOptions,
-    sizes_same: bool,
-    stats_type: str,
-) -> str:
-    """
-    This function gets diagnostics for a dimension of a module.
-
-    Args:
-      dim:
-        the dimension to analyze, with 0 <= dim < tensors[0].ndim
-      options:
-        options object
-      sizes_same:
-        True if all the tensor sizes are the same on this dimension
-      stats_type: either "abs" or "positive" or "eigs" or "value";
-        dictates the type of stats we accumulate.  "abs" is mean absolute
-        value, "positive" is proportion of positive to nonnegative values,
-        "eigs" is eigenvalues after doing outer product on this dim, summed
-        over all other dims.
-    Returns:
-        Diagnostic as a string, either percentiles or the actual values,
-        see the code.  Will return the empty string if the diagnostics did
-        not make sense to print out for this dimension, e.g. dimension
-        mismatch and stats_type == "eigs".
-    """
-
-    # stats_and_counts is a list of pairs (Tensor, int)
-    stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
-    stats = [x[0] for x in stats_and_counts]
-    counts = [x[1] for x in stats_and_counts]
-
-    if stats_type == "eigs":
-        try:
-            stats = torch.stack(stats).sum(dim=0)
-        except:  # noqa
-            return ""
-        count = sum(counts)
-        stats = stats / count
-        try:
-            eigs, _ = torch.symeig(stats)
-            stats = eigs.abs().sqrt()
-        except:  # noqa
-            print("Error getting eigenvalues, trying another method.")
-            eigs = torch.linalg.eigvals(stats)
-            stats = eigs.abs().sqrt()
-        # sqrt so it reflects data magnitude, like stddev, not variance
-    elif sizes_same:
-        stats = torch.stack(stats).sum(dim=0)
-        count = sum(counts)
-        stats = stats / count
-    else:
-        stats = [x[0] / x[1] for x in stats_and_counts]
-        stats = torch.cat(stats, dim=0)
-    if stats_type == "rms":
-        stats = stats.sqrt()
-
-    # if `summarize` we print percentiles of the stats; else,
-    # we print out individual elements.
-    summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
-    if summarize:
-        # print out percentiles.
-        stats = stats.sort()[0]
-        num_percentiles = 10
-        size = stats.numel()
-        percentiles = []
-        for i in range(num_percentiles + 1):
-            index = (i * (size - 1)) // num_percentiles
-            percentiles.append(stats[index].item())
-        percentiles = ["%.2g" % x for x in percentiles]
-        percentiles = " ".join(percentiles)
-        ans = f"percentiles: [{percentiles}]"
-    else:
-        ans = stats.tolist()
-        ans = ["%.2g" % x for x in ans]
-        ans = "[" + " ".join(ans) + "]"
-    if stats_type == "value":
-        # This norm is useful because it is strictly less than the largest
-        # sqrt(eigenvalue) of the variance, which we print out, and shows,
-        # speaking in an approximate way, how much of that largest eigenvalue
-        # can be attributed to the mean of the distribution.
-        norm = (stats ** 2).sum().sqrt().item()
-        mean = stats.mean().item()
-        rms = (stats ** 2).mean().sqrt().item()
-        ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
-    else:
-        mean = stats.mean().item()
-        rms = (stats ** 2).mean().sqrt().item()
-        ans += f", mean={mean:.2g}, rms={rms:.2g}"
-    return ans
-
-
-def print_diagnostics_for_dim(
-    name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions
-):
-    """This function prints diagnostics for a dimension of a tensor.
-
-    Args:
-      name:
-        The tensor name.
-      dim:
-        The dimension to analyze, with 0 <= dim < tensors[0].ndim.
-      tensors:
-        List of cached tensors to get the stats.
-      options:
-        Options object.
-    """
-
-    ndim = tensors[0].ndim
-    if ndim > 1:
-        stats_types = ["abs", "positive", "value", "rms"]
-        if tensors[0].shape[dim] <= options.max_eig_dim:
-            stats_types.append("eigs")
-    else:
-        stats_types = ["value", "abs"]
-
-    for stats_type in stats_types:
-        sizes = [x.shape[dim] for x in tensors]
-        sizes_same = all([x == sizes[0] for x in sizes])
-        s = get_diagnostics_for_dim(
-            dim, tensors, options, sizes_same, stats_type
-        )
-        if s == "":
-            continue
-
-        min_size = min(sizes)
-        max_size = max(sizes)
-        size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
-        # stats_type will be "abs" or "positive".
-        print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
+@dataclass
+class TensorAndCount:
+    tensor: Tensor
+    count: int
 
 
 class TensorDiagnostic(object):
@@ -238,12 +112,23 @@ class TensorDiagnostic(object):
       name:
        The tensor name.
     """
-
     def __init__(self, opts: TensorDiagnosticOptions, name: str):
         self.name = name
         self.opts = opts
-        # A list to cache the tensors.
-        self.saved_tensors = []
+
+        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+
+        # The keys into self.stats[dim] are strings, one of
+        # "abs", "value", "positive", "rms" or "eigs".
+        # The values, e.g. self.stats[dim]["rms"], are lists of the dataclass
+        # TensorAndCount, each pairing a stats tensor with its associated count
+        # (the number of items we aggregated over on the other dims, e.g. the
+        # number of frames and/or batch elements and/or channels).
+        # We accumulate into an existing tensor/count whenever we see the same
+        # size on this dim, only adding a new element to the list when we see
+        # a different size.
+        # For the key "eigs", if we detect a size mismatch on this dim we put
+        # None as the value, which disables that stats type.
 
     def accumulate(self, x):
         """Accumulate tensors."""
@@ -251,50 +136,115 @@ class TensorDiagnostic(object):
             x = x[0]
         if not isinstance(x, Tensor):
             return
-        if x.device == torch.device("cpu"):
-            x = x.detach().clone()
-        else:
-            x = x.detach().to("cpu", non_blocking=True)
-        self.saved_tensors.append(x)
-        num = len(self.saved_tensors)
-        if num & (num - 1) == 0:  # power of 2..
-            self._limit_memory()
+        x = x.detach().clone()
+        if x.ndim == 0:
+            x = x.unsqueeze(0)
+        ndim = x.ndim
+        if self.stats is None:
+            self.stats = [dict() for _ in range(ndim)]
 
-    def _limit_memory(self):
-        """Only keep the newly cached tensors to limit memory."""
-        if len(self.saved_tensors) > 1024:
-            self.saved_tensors = self.saved_tensors[-1024:]
-            return
+        for dim in range(ndim):
+            this_dim_stats = self.stats[dim]
+            if ndim > 1:
+                stats_types = ["abs", "positive", "value", "rms"]
+                if x.shape[dim] <= self.opts.max_eig_dim:
+                    stats_types.append("eigs")
+            else:
+                stats_types = ["value", "abs"]
+            for stats_type in stats_types:
+                stats, count = get_tensor_stats(x, dim, stats_type)
+                if stats_type not in this_dim_stats:
+                    this_dim_stats[stats_type] = []  # list of TensorAndCount
+
+                done = False
+                if this_dim_stats[stats_type] is None:
+                    # We can reach here if we detected for stats_type "eigs" that
+                    # there was more than one size for this dim.  Then we disable
+                    # accumulating this stats type, as it uses too much memory.
+                    continue
+                for s in this_dim_stats[stats_type]:
+                    if s.tensor.shape == stats.shape:
+                        s.tensor += stats
+                        s.count += count
+                        done = True
+                        break
+                if not done:
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                        # >1 size encountered on this dim, e.g. it's a batch or time
+                        # dimension; don't accumulate the "eigs" stats type, as it
+                        # uses too much memory.
+                        this_dim_stats[stats_type] = None
+                    else:
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
-        tot_mem = 0.0
-        for i in reversed(range(len(self.saved_tensors))):
-            tot_mem += (
-                self.saved_tensors[i].numel()
-                * self.saved_tensors[i].element_size()
-            )
-            if tot_mem > self.opts.memory_limit:
-                self.saved_tensors = self.saved_tensors[i:]
-                return
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
-        if len(self.saved_tensors) == 0:
-            print("{name}: no stats".format(name=self.name))
-            return
+        for dim, this_dim_stats in enumerate(self.stats):
+            for stats_type, stats_list in this_dim_stats.items():
+                # stats_type could be "rms", "value", "abs", "eigs", "positive".
+                # stats_list could be a list of TensorAndCount, or None.
+                if stats_list is None:
+                    assert stats_type == "eigs"
+                    continue
 
-        if self.saved_tensors[0].ndim == 0:
-            # Ensure there is at least one dim.
-            self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
+                if stats_type == "eigs":
+                    assert len(stats_list) == 1
+                    stats = stats_list[0].tensor / stats_list[0].count
+                    try:
+                        eigs, _ = torch.symeig(stats)
+                        stats = eigs.abs().sqrt()
+                    except:  # noqa
+                        print("Error getting eigenvalues, trying another method.")
+                        eigs = torch.linalg.eigvals(stats)
+                        stats = eigs.abs().sqrt()
+                    # sqrt so it reflects data magnitude, like stddev, not variance
+                elif len(stats_list) == 1:
+                    stats = stats_list[0].tensor / stats_list[0].count
+                else:
+                    stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
 
-        try:
-            device = torch.device("cuda")
-        except:  # noqa
-            device = torch.device("cpu")
+                if stats_type == "rms":
+                    # we stored the square; after aggregation we take the sqrt.
+                    stats = stats.sqrt()
+
+                # if `summarize` we print percentiles of the stats; else,
+                # we print out individual elements.
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
+                if summarize:  # usually `summarize` will be true
+                    # print out percentiles.
+                    stats = stats.sort()[0]
+                    num_percentiles = 10
+                    size = stats.numel()
+                    percentiles = []
+                    for i in range(num_percentiles + 1):
+                        index = (i * (size - 1)) // num_percentiles
+                        percentiles.append(stats[index].item())
+                    percentiles = ["%.2g" % x for x in percentiles]
+                    percentiles = " ".join(percentiles)
+                    ans = f"percentiles: [{percentiles}]"
+                else:
+                    ans = stats.tolist()
+                    ans = ["%.2g" % x for x in ans]
+                    ans = "[" + " ".join(ans) + "]"
+                if stats_type == "value":
+                    # This norm is useful because it is strictly less than the largest
+                    # sqrt(eigenvalue) of the variance, which we print out, and shows,
+                    # speaking in an approximate way, how much of that largest eigenvalue
+                    # can be attributed to the mean of the distribution.
+                    norm = (stats ** 2).sum().sqrt().item()
+                    ans += f", norm={norm:.2g}"
+                mean = stats.mean().item()
+                rms = (stats ** 2).mean().sqrt().item()
+                ans += f", mean={mean:.2g}, rms={rms:.2g}"
+
+                # OK, "ans" contains the actual stats, e.g.
+                # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
+
+                sizes = [x.tensor.shape[0] for x in stats_list]
+                size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
 
-        ndim = self.saved_tensors[0].ndim
-        tensors = [x.to(device) for x in self.saved_tensors]
-        for dim in range(ndim):
-            print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
 
 class ModelDiagnostic(object):
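
To make the new accumulation scheme concrete, here is a minimal standalone sketch of the TensorAndCount bookkeeping for a single dim. It is not part of the patch: accumulate_abs is a hypothetical helper that inlines roughly what get_tensor_stats computes for the "abs" stats type. Stats tensors with the same size on the chosen dim are merged in place by summing both the tensor and the count; a new size appends a fresh entry, which is what later makes print_diagnostics fall back to percentile summaries across concatenated entries.

# Standalone sketch (assumptions as noted above), runnable with PyTorch.
from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class TensorAndCount:
    tensor: Tensor
    count: int


def accumulate_abs(stats_list, x: Tensor, dim: int) -> None:
    # Hypothetical stand-in for get_tensor_stats(x, dim, "abs"): sum |x|
    # over all dims other than `dim`, and count how many elements went
    # into each position (frames/batch elements/channels in the patch).
    other_dims = [d for d in range(x.ndim) if d != dim]
    stats = x.abs().sum(dim=other_dims) if other_dims else x.abs()
    count = x.numel() // x.shape[dim]
    for s in stats_list:
        if s.tensor.shape == stats.shape:  # same size on this dim: merge
            s.tensor += stats
            s.count += count
            return
    stats_list.append(TensorAndCount(stats, count))  # new size: new entry


stats_list = []
accumulate_abs(stats_list, torch.randn(4, 8), dim=1)   # batch of 4
accumulate_abs(stats_list, torch.randn(6, 8), dim=1)   # merged: count 4+6
accumulate_abs(stats_list, torch.randn(4, 10), dim=1)  # new size: new entry
print([(tuple(s.tensor.shape), s.count) for s in stats_list])
# -> [((8,), 10), ((10,), 4)]

Dividing each entry's tensor by its count recovers the mean statistic, which is what print_diagnostics does before concatenating entries of different sizes; the "eigs" type is the exception and is disabled (set to None) once more than one size is seen on a dim, since, as the patch's comments note, accumulating it per size would use too much memory.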