diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 944f11f64..fba1c525a 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -1,5 +1,6 @@
 # Copyright      2022  Xiaomi Corp.        (authors: Daniel Povey
-#                                                    Zengwei Yao)
+#                                                    Zengwei Yao
+#                                                    Mingshuang Luo)
 #
 # See ../LICENSE for clarification regarding multiple authors
 #
@@ -28,51 +29,67 @@ class TensorDiagnosticOptions(object):
     Args:
       memory_limit:
-        The maximum number of bytes per tensor (limits how many copies
-        of the tensor we cache).
+        The maximum number of bytes per tensor
+        (limits how many copies of the tensor we cache).
+      max_eig_dim:
+        The maximum dimension for which we print out eigenvalues
+        (limited for speed reasons).
     """
 
-    def __init__(self, memory_limit: int):
+    def __init__(
+        self,
+        memory_limit: int = (2 ** 20),
+        max_eig_dim: int = 512
+    ):
         self.memory_limit = memory_limit
+        self.max_eig_dim = max_eig_dim
 
     def dim_is_summarized(self, size: int):
         return size > 10 and size != 31
 
 
-def get_sum_abs_stats(
+def get_tensor_stats(
     x: Tensor, dim: int, stats_type: str
 ) -> Tuple[Tensor, int]:
-    """Returns the sum-of-absolute-value of this Tensor, for each index into
-    the specified axis/dim of the tensor.
+    """
+    Returns the specified transformation of the Tensor (either x, x.abs()
+    or (x > 0)), summed over all but the index `dim`.
 
     Args:
-      x:
-        Tensor, tensor to be analyzed
-      dim:
-        Dimension with 0 <= dim < x.ndim
+      x: Tensor, tensor to be analyzed
+      dim: dimension with 0 <= dim < x.ndim
       stats_type:
-        Either "mean-abs" in which case the stats represent the mean absolute
-        value, or "pos-ratio" in which case the stats represent the proportion
-        of positive values (actually: the tensor is count of positive values,
-        count is the count of all values).
-
-    Returns:
-      (sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],),
-      and the count is an integer saying how many items were counted in
-      each element of sum_abs.
+        "abs" -> take abs() before summing
+        "positive" -> take (x > 0) before summing
+        "rms" -> square before summing, we'll take sqrt later
+        "value" -> just sum x itself
+        "eigs" -> accumulate the outer product over `dim` (stats is
+          then a square matrix rather than a vector)
+    Returns (stats, count)
+      where stats is a Tensor of shape (x.shape[dim],), and the count
+      is an integer saying how many items were counted in each element
+      of stats.
     """
-    if stats_type == "mean-abs":
+
+    count = x.numel() // x.shape[dim]
+
+    if stats_type == "eigs":
+        x = x.transpose(dim, -1)
+        x = x.reshape(-1, x.shape[-1])
+        # shape of returned tensor: (s, s),
+        # where s is size of dimension `dim` of original x.
+        return torch.matmul(x.transpose(0, 1), x), count
+    elif stats_type == "abs":
         x = x.abs()
-    else:
-        assert stats_type == "pos-ratio"
+    elif stats_type == "rms":
+        x = x ** 2
+    elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
+    else:
+        assert stats_type == "value"
 
-    orig_numel = x.numel()
-    sum_dims = [d for d in range(x.ndim) if d != dim]
-    x = torch.sum(x, dim=sum_dims)
-    count = orig_numel // x.numel()
+    sum_dims = [ d for d in range(x.ndim) if d != dim ]
+    if len(sum_dims) > 0:
+        x = torch.sum(x, dim=sum_dims)
     x = x.flatten()
-
     return x, count
 
 
@@ -83,43 +100,55 @@ def get_diagnostics_for_dim(
     sizes_same: bool,
     stats_type: str,
 ) -> str:
-    """This function gets diagnostics for a dimension of a module.
-
+    """
+    This function gets diagnostics for a dimension of a module.
     Args:
-      dim:
-        The dimension to analyze, with 0 <= dim < tensors[0].ndim
-      tensors:
-        List of cached tensors to get the stats
-      options:
-        Options object
-      sizes_same:
-        True if all the tensor sizes are the same on this dimension
-      stats_type: either "mean-abs" or "pos-ratio", dictates the type of
-        stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
-        proportion of positive to nonnegative values.
-
+      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
+      options: options object
+      sizes_same: true if all the tensor sizes are the same on this dimension
+      stats_type: either "abs", "positive", "rms", "eigs" or "value";
+        dictates the type of stats we accumulate.  "abs" is mean absolute
+        value, "positive" is the proportion of positive values, "eigs"
+        is eigenvalues after doing outer product on this dim, summed
+        over all other dims.
     Returns:
-      Diagnostic as a string, either percentiles or the actual values,
-      see the code.
+      Diagnostic as a string, either percentiles or the actual values,
+      see the code.  Will return the empty string if the diagnostics did
+      not make sense to print out for this dimension, e.g. dimension
+      mismatch and stats_type == "eigs".
     """
     # stats_and_counts is a list of pair (Tensor, int)
-    stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
-    stats = [x[0] for x in stats_and_counts]
-    counts = [x[1] for x in stats_and_counts]
-    if sizes_same:
+    stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
+    stats = [ x[0] for x in stats_and_counts ]
+    counts = [ x[1] for x in stats_and_counts ]
+
+    if stats_type == "eigs":
+        try:
+            stats = torch.stack(stats).sum(dim=0)
+        except:
+            return ''
+        count = sum(counts)
+        stats = stats / count
+        stats, _ = torch.symeig(stats)
+        stats = stats.abs().sqrt()
+        # sqrt so it reflects data magnitude, like stddev, not variance
+    elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
         count = sum(counts)
         stats = stats / count
     else:
-        stats = [x[0] / x[1] for x in stats_and_counts]
+        stats = [ x[0] / x[1] for x in stats_and_counts ]
         stats = torch.cat(stats, dim=0)
+    if stats_type == 'rms':
+        stats = stats.sqrt()
 
-    # If `summarize` we print percentiles of the stats;
-    # else, we print out individual elements.
+    # if `summarize` we print percentiles of the stats; else,
+    # we print out individual elements.
     summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
     if summarize:
-        # Print out percentiles.
+        # print out percentiles.
         stats = stats.sort()[0]
         num_percentiles = 10
         size = stats.numel()
@@ -127,14 +156,27 @@ def get_diagnostics_for_dim(
         for i in range(num_percentiles + 1):
             index = (i * (size - 1)) // num_percentiles
             percentiles.append(stats[index].item())
-        percentiles = ["%.2g" % x for x in percentiles]
-        percentiles = " ".join(percentiles)
-        return f"percentiles: [{percentiles}]"
+        percentiles = [ '%.2g' % x for x in percentiles ]
+        percentiles = ' '.join(percentiles)
+        ans = f'percentiles: [{percentiles}]'
     else:
-        stats = stats.tolist()
-        stats = ["%.2g" % x for x in stats]
-        stats = "[" + " ".join(stats) + "]"
-        return stats
+        ans = stats.tolist()
+        ans = [ '%.2g' % x for x in ans ]
+        ans = '[' + ' '.join(ans) + ']'
+    if stats_type == "value":
+        # This norm is useful because it is strictly less than the largest
+        # sqrt(eigenvalue) of the variance, which we print out, and shows,
+        # in an approximate way, how much of that largest eigenvalue
+        # can be attributed to the mean of the distribution.
+        norm = (stats ** 2).sum().sqrt().item()
+        mean = stats.mean().item()
+        rms = (stats ** 2).mean().sqrt().item()
+        ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}'
+    else:
+        mean = stats.mean().item()
+        rms = (stats ** 2).mean().sqrt().item()
+        ans += f', mean={mean:.2g}, rms={rms:.2g}'
+    return ans
 
 
 def print_diagnostics_for_dim(
@@ -153,17 +195,27 @@ def print_diagnostics_for_dim(
       Options object.
     """
 
-    for stats_type in ["mean-abs", "pos-ratio"]:
-        # stats_type will be "mean-abs" or "pos-ratio".
-        sizes = [x.shape[dim] for x in tensors]
-        sizes_same = all([x == sizes[0] for x in sizes])
-        s = get_diagnostics_for_dim(
-            dim, tensors, options, sizes_same, stats_type
-        )
+    ndim = tensors[0].ndim
+    if ndim > 1:
+        stats_types = ["abs", "positive", "value", "rms"]
+        if tensors[0].shape[dim] <= options.max_eig_dim:
+            stats_types.append("eigs")
+    else:
+        stats_types = [ "value", "abs" ]
+
+    for stats_type in stats_types:
+        sizes = [ x.shape[dim] for x in tensors ]
+        sizes_same = all([ x == sizes[0] for x in sizes ])
+        s = get_diagnostics_for_dim(dim, tensors,
+                                    options, sizes_same,
+                                    stats_type)
+        if s == '':
+            continue
         min_size = min(sizes)
         max_size = max(sizes)
         size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
+        # stats_type is one of "abs", "positive", "value", "rms" or "eigs".
         print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
 
@@ -225,10 +277,17 @@ class TensorDiagnostic(object):
             # Ensure there is at least one dim.
             self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
 
+        try:
+            device = torch.device("cuda")
+            torch.ones(1, 1, device=device)
+        except:
+            device = torch.device("cpu")
+
         ndim = self.saved_tensors[0].ndim
+        tensors = [x.to(device) for x in self.saved_tensors]
         for dim in range(ndim):
             print_diagnostics_for_dim(
-                self.name, dim, self.saved_tensors, self.opts
+                self.name, dim, tensors, self.opts
             )
 
@@ -240,7 +299,7 @@ class ModelDiagnostic(object):
      Options object.
    """
 
-    def __init__(self, opts: TensorDiagnosticOptions):
+    def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
        # In this dictionary, the keys are tensors names and the values
        # are corresponding TensorDiagnostic objects.
        self.diagnostics = dict()
@@ -321,7 +380,7 @@ def attach_diagnostics(
 
 def _test_tensor_diagnostic():
-    opts = TensorDiagnosticOptions(2 ** 20)
+    opts = TensorDiagnosticOptions(2**20, 512)
     diagnostic = TensorDiagnostic(opts, "foo")
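
For reviewers, a minimal usage sketch of the new options (illustrative only,
not part of the patch; it assumes the existing TensorDiagnostic.accumulate()
and print_diagnostics() methods that _test_tensor_diagnostic() exercises,
which this diff does not change):

    import torch
    from icefall.diagnostics import TensorDiagnostic, TensorDiagnosticOptions

    # memory_limit caps the bytes cached per tensor; max_eig_dim caps the
    # dimension size for which the new "eigs" stats are computed.
    opts = TensorDiagnosticOptions(memory_limit=2 ** 20, max_eig_dim=512)
    diagnostic = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        # assumed API: accumulate() caches copies of the tensor for analysis
        diagnostic.accumulate(torch.randn(50, 100) * 10.0)
    # prints per-dim "abs"/"positive"/"value"/"rms" stats, plus "eigs"
    # here since both dims (50 and 100) are <= max_eig_dim
    diagnostic.print_diagnostics()

With the new defaults added in this patch, TensorDiagnosticOptions() with no
arguments is equivalent to the explicit call above.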