From b9696878b427add4a8384834bb14af86a81179ab Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 26 Jul 2022 12:30:07 +0800 Subject: [PATCH] Update diagnostics stats --- icefall/diagnostics.py | 145 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 135 insertions(+), 10 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 01bf552cc..ba56a2425 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -19,7 +19,7 @@ import random from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, List import torch from torch import Tensor, nn @@ -188,14 +188,22 @@ class TensorDiagnostic(object): for dim, this_dim_stats in enumerate(self.stats): for stats_type, stats_list in this_dim_stats.items(): # stats_type could be "rms", "value", "abs", "eigs", "positive". - # "value" could be a list of TensorAndCount, or None + # "stats_list" could be a list of TensorAndCount (one list per distinct tensor + # shape of the stats), or None if stats_list is None: assert stats_type == "eigs" continue - if stats_type == "eigs": - assert len(stats_list) == 1 + if len(stats_list) == 1: stats = stats_list[0].tensor / stats_list[0].count + else: + # a dimension that has variable size in different nnet + # forwards, e.g. a time dimension in an ASR model. + stats = torch.cat( + [x.tensor / x.count for x in stats_list], dim=0 + ) + + if stats_type == "eigs": try: eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() @@ -206,12 +214,6 @@ class TensorDiagnostic(object): eigs = torch.linalg.eigvals(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - elif len(stats_list) == 1: - stats = stats_list[0].tensor / stats_list[0].count - else: - stats = torch.cat( - [x.tensor / x.count for x in stats_list], dim=0 - ) if stats_type == "rms": # we stored the square; after aggregation we need to take sqrt. @@ -264,6 +266,117 @@ class TensorDiagnostic(object): ) + def print_joint_diagnostics(self, other: 'TensorDiagnostic'): + """ + Prints diagnostics that relate to correlations between the 'basic' diagnostics + printed in print_diagnostics(). + """ + combined_name = _summarize_two_names(self.name, other.name) + # e.g. combined_name == 'foo.{param_value,param_grad}' or just 'foo.param_value' if self.name == other.name. + for dim, this_dim_stats in enumerate(self.stats): + try: + other_dim_stats = other.stats[dim] + except (TypeError, IndexError): + print(f"Continuing, dim={dim}, (0)") + continue + + output_list = [] + for stats_type, stats_list in this_dim_stats.items(): + # stats_type could be "rms", "value", "abs", "eigs", "positive". + # "stats_list" could be a list of TensorAndCount (one list per distinct tensor + # shape of the stats), or None + if stats_list is None: + continue + # work out `size_str`, will be used to print out data later.. this is the + # same for all `stats_type` values + sizes = [x.tensor.shape[0] for x in stats_list] + size_str = ( + f"{sizes[0]}" + if len(sizes) == 1 + else f"{min(sizes)}..{max(sizes)}" + ) + + if len(stats_list) == 1: + stats = stats_list[0].tensor / stats_list[0].count + else: + stats = torch.cat( + [x.tensor / x.count for x in stats_list], dim=0 + ) + + other_stats_list = other_dim_stats[stats_type] + for other_stats_type, other_stats_list in other_dim_stats.items(): + # avoid redundantly comparing a,b and b,a + if (other_stats_type > stats_type or other_stats_list is None or + len(other_stats_list) == 0 or stats_list is other_stats_list): + continue + if len(stats_list) == 1: + other_stats = other_stats_list[0].tensor / other_stats_list[0].count + size = stats.shape[0] + else: + other_stats = torch.cat( + [x.tensor / x.count for x in other_stats_list], dim=0 + ) + if other_stats.shape != stats.shape: + # e.g. stats_type == "eigs" and other_stats_type != + # "eigs" or the other way around + continue + + + if stats.ndim == 2: + # Matrices, for purposes of measuring eigenvalues. Just compute a dot-product-related + # measure of correlation. + correlation = ((stats * other_stats).sum() / + ((stats**2).sum() * (other_stats**2).sum() + 1.0e-20).sqrt()) + else: + # ndim == 1 + # Use a rank-based measure of correlation + (_, indices1) = stats.sort() + (_, indices2) = other_stats.sort() + n = stats.numel() + rank1 = ((indices1 + 0.5) / n) - 0.5 + rank2 = ((indices2 + 0.5) / n) - 0.5 + correlation = (rank1 * rank2).sum() / (rank1 * rank1).sum() + output_list.append(f'{stats_type},{other_stats_type}={correlation:.3f}') + if len(output_list) == 0: + continue + + maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" + output = f"module={combined_name}{maybe_class_name} dim={dim} size={size_str}: " + ' '.join(output_list) + print(output) + + + +def _summarize_two_names(a: str, b:str, separator: str = ',') -> str: + """ + Given 'foo.ab' and 'foo.xyz', returns 'foo.{ab,xyz}'. If a == b, + returns a. + """ + if a == b: + return a + num_common_chars = min(len(a), len(b)) + for i in range(num_common_chars): + if a[i] != b[i]: + num_common_chars = i + break + return '%s{%s%s%s}' % (a[:num_common_chars], a[num_common_chars:], + separator, b[num_common_chars:]) + +def _get_comparison_keys(k: str) -> List[str]: + """ + Gets names of diagnostic objects to compare with this one (including itself). + If k is "something.output" or "something.grad", will return ["something.output", "something.grad"] + If k is "something.param_value" or "something.param_grad", will return + """ + ending_sets = [ ['.output', '.grad'], ['.output[0]', '.grad[0]'], ['.output[1]', '.grad[1]'], + ['.output[2]', '.grad[2]'], ['.param_value', '.param_grad'] ] + for s in ending_sets: + for end in s: + if k.endswith(end): + prefix = k[:-len(end)] + return [ prefix + suffix for suffix in s] + return [k] + + class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -290,6 +403,14 @@ class ModelDiagnostic(object): """Print diagnostics for each tensor.""" for k in sorted(self.diagnostics.keys()): self.diagnostics[k].print_diagnostics() + for l in _get_comparison_keys(k): + if l >= k: # this ensures we don't print redundant correlations + # for (a,b) and (b,a), since they are symmetric. + try: + self.diagnostics[k].print_joint_diagnostics( + self.diagnostics[l]) + except KeyError: + pass def attach_diagnostics( @@ -324,6 +445,8 @@ def attach_diagnostics( def forward_hook( _module, _input, _output, _model_diagnostic=ans, _name=name ): + if isinstance(_output, tuple) and len(_output) == 1: + _output = _output[0] if isinstance(_output, Tensor): _model_diagnostic[f"{_name}.output"].accumulate(_output, @@ -336,6 +459,8 @@ def attach_diagnostics( def backward_hook( _module, _input, _output, _model_diagnostic=ans, _name=name ): + if isinstance(_output, tuple) and len(_output) == 1: + _output = _output[0] if isinstance(_output, Tensor): _model_diagnostic[f"{_name}.grad"].accumulate(_output, class_name=type(_module).__name__)