From 16dda9672f75953b0f4f0b727453752d4a9a35f8 Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Tue, 15 Mar 2022 20:31:53 +0800 Subject: [PATCH] do some changes --- icefall/diagnostics.py | 128 +++++++++++++++++++++-------------------- 1 file changed, 66 insertions(+), 62 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fba1c525a..fa9b98fa0 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -18,7 +18,7 @@ import random -from typing import List, Tuple +from typing import List, Optional, Tuple import torch from torch import Tensor, nn @@ -29,18 +29,14 @@ class TensorDiagnosticOptions(object): Args: memory_limit: - The maximum number of bytes per tensor + The maximum number of bytes per tensor (limits how many copies of the tensor we cache). max_eig_dim: The maximum dimension for which we print out eigenvalues (limited for speed reasons). """ - def __init__( - self, - memory_limit: int = (2 ** 20), - max_eig_dim: int = 512 - ): + def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512): self.memory_limit = memory_limit self.max_eig_dim = max_eig_dim @@ -49,24 +45,29 @@ class TensorDiagnosticOptions(object): def get_tensor_stats( - x: Tensor, dim: int, stats_type: str + x: Tensor, + dim: int, + stats_type: str, ) -> Tuple[Tensor, int]: """ Returns the specified transformation of the Tensor (either x or x.abs() or (x > 0), summed over all but the index `dim`. Args: - x: Tensor, tensor to be analyzed - dim: dimension with 0 <= dim < x.ndim + x: + Tensor, tensor to be analyzed + dim: + Dimension with 0 <= dim < x.ndim stats_type: - "abs" -> take abs() before summing - "positive" -> take (x > 0) before summing - "rms" -> square before summing, we'll take sqrt later - "value -> just sum x itself - Returns (stats, count) - where stats is a Tensor of shape (x.shape[dim],), and the count - is an integer saying how many items were counted in each element - of stats. 
+ The stats_type includes several types: + "abs" -> take abs() before summing + "positive" -> take (x > 0) before summing + "rms" -> square before summing, we'll take sqrt later + "value" -> just sum x itself + Returns: + stats: a Tensor of shape (x.shape[dim],). + count: an integer saying how many items were counted in each element + of stats. """ count = x.numel() // x.shape[dim] @@ -86,7 +87,7 @@ def get_tensor_stats( else: assert stats_type == "value" - sum_dims = [ d for d in range(x.ndim) if d != dim ] + sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: x = torch.sum(x, dim=sum_dims) x = x.flatten() @@ -102,46 +103,49 @@ def get_diagnostics_for_dim( ) -> str: """ This function gets diagnostics for a dimension of a module. + Args: - dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim - options: options object - sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value", - imdictates the type of stats - we accumulate, abs is mean absolute value, "positive" - is proportion of positive to nonnegative values, "eigs" - is eigenvalues after doing outer product on this dim, sum - over all other dimes. + dim: + the dimension to analyze, with 0 <= dim < tensors[0].ndim + options: + options object + sizes_same: + True if all the tensor sizes are the same on this dimension + stats_type: either "abs" or "positive" or "eigs" or "value", + dictates the type of stats we accumulate, abs is mean absolute + value, "positive" is proportion of positive to nonnegative values, + "eigs" is eigenvalues after doing outer product on this dim, sum + over all other dims. Returns: - Diagnostic as a string, either percentiles or the actual values, - see the code. Will return the empty string if the diagnostics did - not make sense to print out for this dimension, e.g. 
dimension - mismatch and stats_type == "eigs" + Diagnostic as a string, either percentiles or the actual values, + see the code. Will return the empty string if the diagnostics did + not make sense to print out for this dimension, e.g. dimension + mismatch and stats_type == "eigs". """ # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] - stats = [ x[0] for x in stats_and_counts ] - counts = [ x[1] for x in stats_and_counts ] + stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors] + stats = [x[0] for x in stats_and_counts] + counts = [x[1] for x in stats_and_counts] if stats_type == "eigs": try: stats = torch.stack(stats).sum(dim=0) - except: - return '' + except: # noqa + return "" count = sum(counts) stats = stats / count stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) count = sum(counts) stats = stats / count else: - stats = [ x[0] / x[1] for x in stats_and_counts ] + stats = [x[0] / x[1] for x in stats_and_counts] stats = torch.cat(stats, dim=0) - if stats_type == 'rms': + if stats_type == "rms": stats = stats.sqrt() # if `summarize` we print percentiles of the stats; else, @@ -156,13 +160,13 @@ def get_diagnostics_for_dim( for i in range(num_percentiles + 1): index = (i * (size - 1)) // num_percentiles percentiles.append(stats[index].item()) - percentiles = [ '%.2g' % x for x in percentiles ] - percentiles = ' '.join(percentiles) - ans = f'percentiles: [{percentiles}]' + percentiles = ["%.2g" % x for x in percentiles] + percentiles = " ".join(percentiles) + ans = f"percentiles: [{percentiles}]" else: ans = stats.tolist() - ans = [ '%.2g' % x for x in ans ] - ans = '[' + ' '.join(ans) + ']' + ans = ["%.2g" % x for x in ans] + ans = "[" + " ".join(ans) + "]" if stats_type == "value": # This norm is useful 
because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, @@ -171,11 +175,11 @@ def get_diagnostics_for_dim( norm = (stats ** 2).sum().sqrt().item() mean = stats.mean().item() rms = (stats ** 2).mean().sqrt().item() - ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}' + ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}" else: mean = stats.mean().item() rms = (stats ** 2).mean().sqrt().item() - ans += f', mean={mean:.2g}, rms={rms:.2g}' + ans += f", mean={mean:.2g}, rms={rms:.2g}" return ans @@ -201,15 +205,15 @@ def print_diagnostics_for_dim( if tensors[0].shape[dim] <= options.max_eig_dim: stats_types.append("eigs") else: - stats_types = [ "value", "abs" ] + stats_types = ["value", "abs"] for stats_type in stats_types: - sizes = [ x.shape[dim] for x in tensors ] - sizes_same = all([ x == sizes[0] for x in sizes ]) - s = get_diagnostics_for_dim(dim, tensors, - options, sizes_same, - stats_type) - if s == '': + sizes = [x.shape[dim] for x in tensors] + sizes_same = all([x == sizes[0] for x in sizes]) + s = get_diagnostics_for_dim( + dim, tensors, options, sizes_same, stats_type + ) + if s == "": continue min_size = min(sizes) @@ -279,16 +283,13 @@ class TensorDiagnostic(object): try: device = torch.device("cuda") - torch.ones(1, 1, device) - except: + except: # noqa device = torch.device("cpu") ndim = self.saved_tensors[0].ndim tensors = [x.to(device) for x in self.saved_tensors] for dim in range(ndim): - print_diagnostics_for_dim( - self.name, dim, tensors, self.opts - ) + print_diagnostics_for_dim(self.name, dim, tensors, self.opts) class ModelDiagnostic(object): @@ -299,11 +300,14 @@ class ModelDiagnostic(object): Options object. """ - def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()): + def __init__(self, opts: Optional[TensorDiagnosticOptions] = None): # In this dictionary, the keys are tensors names and the values # are corresponding TensorDiagnostic objects. 
+ if opts is None: + self.opts = TensorDiagnosticOptions() + else: + self.opts = opts self.diagnostics = dict() - self.opts = opts def __getitem__(self, name: str): if name not in self.diagnostics: @@ -380,7 +384,7 @@ def attach_diagnostics( def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, 512) + opts = TensorDiagnosticOptions(2 ** 20, 512) diagnostic = TensorDiagnostic(opts, "foo")