Update diagnostics.py (#254)

* update diagnostics.py

* do some changes
This commit is contained in:
Mingshuang Luo 2022-03-16 20:17:45 +08:00 committed by GitHub
parent a7643301ec
commit 518ec6414a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
# Zengwei Yao) # Zengwei Yao
# Mingshuang Luo)
# #
# See ../LICENSE for clarification regarding multiple authors # See ../LICENSE for clarification regarding multiple authors
# #
@ -17,7 +18,7 @@
import random import random
from typing import List, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -28,22 +29,29 @@ class TensorDiagnosticOptions(object):
Args: Args:
memory_limit: memory_limit:
The maximum number of bytes per tensor (limits how many copies The maximum number of bytes per tensor
of the tensor we cache). (limits how many copies of the tensor we cache).
max_eig_dim:
The maximum dimension for which we print out eigenvalues
(limited for speed reasons).
""" """
def __init__(self, memory_limit: int): def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512):
self.memory_limit = memory_limit self.memory_limit = memory_limit
self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int): def dim_is_summarized(self, size: int):
return size > 10 and size != 31 return size > 10 and size != 31
def get_sum_abs_stats( def get_tensor_stats(
x: Tensor, dim: int, stats_type: str x: Tensor,
dim: int,
stats_type: str,
) -> Tuple[Tensor, int]: ) -> Tuple[Tensor, int]:
"""Returns the sum-of-absolute-value of this Tensor, for each index into """
the specified axis/dim of the tensor. Returns the specified transformation of the Tensor (either x or x.abs()
or (x > 0)), summed over all but the index `dim`.
Args: Args:
x: x:
@ -51,28 +59,38 @@ def get_sum_abs_stats(
dim: dim:
Dimension with 0 <= dim < x.ndim Dimension with 0 <= dim < x.ndim
stats_type: stats_type:
Either "mean-abs" in which case the stats represent the mean absolute The stats_type includes several types:
value, or "pos-ratio" in which case the stats represent the proportion "abs" -> take abs() before summing
of positive values (actually: the tensor is count of positive values, "positive" -> take (x > 0) before summing
count is the count of all values). "rms" -> square before summing, we'll take sqrt later
"value" -> just sum x itself
Returns: Returns:
(sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],), stats: a Tensor of shape (x.shape[dim],).
and the count is an integer saying how many items were counted in count: an integer saying how many items were counted in each element
each element of sum_abs. of stats.
""" """
if stats_type == "mean-abs":
count = x.numel() // x.shape[dim]
if stats_type == "eigs":
x = x.transpose(dim, -1)
x = x.reshape(-1, x.shape[-1])
# shape of returned tensor: (s, s),
# where s is size of dimension `dim` of original x.
return torch.matmul(x.transpose(0, 1), x), count
elif stats_type == "abs":
x = x.abs() x = x.abs()
else: elif stats_type == "rms":
assert stats_type == "pos-ratio" x = x ** 2
elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float) x = (x > 0).to(dtype=torch.float)
else:
assert stats_type == "value"
orig_numel = x.numel()
sum_dims = [d for d in range(x.ndim) if d != dim] sum_dims = [d for d in range(x.ndim) if d != dim]
if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims) x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten() x = x.flatten()
return x, count return x, count
@ -83,43 +101,58 @@ def get_diagnostics_for_dim(
sizes_same: bool, sizes_same: bool,
stats_type: str, stats_type: str,
) -> str: ) -> str:
"""This function gets diagnostics for a dimension of a module. """
This function gets diagnostics for a dimension of a module.
Args: Args:
dim: dim:
The dimension to analyze, with 0 <= dim < tensors[0].ndim the dimension to analyze, with 0 <= dim < tensors[0].ndim
tensors:
List of cached tensors to get the stats
options: options:
Options object options object
sizes_same: sizes_same:
True if all the tensor sizes are the same on this dimension True if all the tensor sizes are the same on this dimension
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats_type: one of "abs", "positive", "rms", "value" or "eigs",
stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is dictates the type of stats we accumulate, abs is mean absolute
proportion of positive to nonnegative values. value, "positive" is proportion of positive to nonnegative values,
"eigs" is eigenvalues after doing outer product on this dim, sum
over all other dims.
Returns: Returns:
Diagnostic as a string, either percentiles or the actual values, Diagnostic as a string, either percentiles or the actual values,
see the code. see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs".
""" """
# stats_and_counts is a list of pair (Tensor, int) # stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors] stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
stats = [x[0] for x in stats_and_counts] stats = [x[0] for x in stats_and_counts]
counts = [x[1] for x in stats_and_counts] counts = [x[1] for x in stats_and_counts]
if sizes_same:
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except: # noqa
return ""
count = sum(counts)
stats = stats / count
stats, _ = torch.symeig(stats)
stats = stats.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0) stats = torch.stack(stats).sum(dim=0)
count = sum(counts) count = sum(counts)
stats = stats / count stats = stats / count
else: else:
stats = [x[0] / x[1] for x in stats_and_counts] stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0) stats = torch.cat(stats, dim=0)
if stats_type == "rms":
stats = stats.sqrt()
# If `summarize` we print percentiles of the stats; # if `summarize` we print percentiles of the stats; else,
# else, we print out individual elements. # we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize: if summarize:
# Print out percentiles. # print out percentiles.
stats = stats.sort()[0] stats = stats.sort()[0]
num_percentiles = 10 num_percentiles = 10
size = stats.numel() size = stats.numel()
@ -129,12 +162,25 @@ def get_diagnostics_for_dim(
percentiles.append(stats[index].item()) percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles] percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles) percentiles = " ".join(percentiles)
return f"percentiles: [{percentiles}]" ans = f"percentiles: [{percentiles}]"
else: else:
stats = stats.tolist() ans = stats.tolist()
stats = ["%.2g" % x for x in stats] ans = ["%.2g" % x for x in ans]
stats = "[" + " ".join(stats) + "]" ans = "[" + " ".join(ans) + "]"
return stats if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
else:
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
return ans
def print_diagnostics_for_dim( def print_diagnostics_for_dim(
@ -153,17 +199,27 @@ def print_diagnostics_for_dim(
Options object. Options object.
""" """
for stats_type in ["mean-abs", "pos-ratio"]: ndim = tensors[0].ndim
# stats_type will be "mean-abs" or "pos-ratio". if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
for stats_type in stats_types:
sizes = [x.shape[dim] for x in tensors] sizes = [x.shape[dim] for x in tensors]
sizes_same = all([x == sizes[0] for x in sizes]) sizes_same = all([x == sizes[0] for x in sizes])
s = get_diagnostics_for_dim( s = get_diagnostics_for_dim(
dim, tensors, options, sizes_same, stats_type dim, tensors, options, sizes_same, stats_type
) )
if s == "":
continue
min_size = min(sizes) min_size = min(sizes)
max_size = max(sizes) max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type is one of "abs", "positive", "value", "rms" or "eigs".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
@ -225,11 +281,15 @@ class TensorDiagnostic(object):
# Ensure there is at least one dim. # Ensure there is at least one dim.
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors] self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
try:
device = torch.device("cuda")
except: # noqa
device = torch.device("cpu")
ndim = self.saved_tensors[0].ndim ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim): for dim in range(ndim):
print_diagnostics_for_dim( print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
self.name, dim, self.saved_tensors, self.opts
)
class ModelDiagnostic(object): class ModelDiagnostic(object):
@ -240,11 +300,14 @@ class ModelDiagnostic(object):
Options object. Options object.
""" """
def __init__(self, opts: TensorDiagnosticOptions): def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
# In this dictionary, the keys are tensors names and the values # In this dictionary, the keys are tensors names and the values
# are corresponding TensorDiagnostic objects. # are corresponding TensorDiagnostic objects.
self.diagnostics = dict() if opts is None:
self.opts = TensorDiagnosticOptions()
else:
self.opts = opts self.opts = opts
self.diagnostics = dict()
def __getitem__(self, name: str): def __getitem__(self, name: str):
if name not in self.diagnostics: if name not in self.diagnostics:
@ -321,7 +384,7 @@ def attach_diagnostics(
def _test_tensor_diagnostic(): def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2 ** 20) opts = TensorDiagnosticOptions(2 ** 20, 512)
diagnostic = TensorDiagnostic(opts, "foo") diagnostic = TensorDiagnostic(opts, "foo")