Update diagnostics

This commit is contained in:
Daniel Povey 2022-03-10 10:28:48 +08:00
parent d074cf73c6
commit 1e5455ba29

View File

@ -11,24 +11,21 @@ class TensorDiagnosticOptions(object):
Options object for tensor diagnostics:
Args:
memory_limit: the maximum number of bytes per tensor (limits how many copies
memory_limit: the maximum number of bytes we store per tensor (limits how many copies
of the tensor we cache).
max_eig_dim: the maximum dimension for which we print out eigenvalues
(limited for speed reasons).
"""
def __init__(self, memory_limit: int,
print_pos_ratio: bool = True):
def __init__(self,
memory_limit: int = (2 ** 20),
max_eig_dim: int = 512):
self.memory_limit = memory_limit
self.print_pos_ratio = print_pos_ratio
self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int):
return size > 10 and size != 31
def stats_types(self):
if self.print_pos_ratio:
return ["mean-abs", "pos-ratio", "value"]
else:
return ["mean-abs"]
def get_tensor_stats(x: Tensor, dim: int,
@ -41,8 +38,9 @@ def get_tensor_stats(x: Tensor, dim: int,
x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim
stats_type:
"mean-abs" or "abs-value" -> take abs() before summing
"pos-ratio" -> take (x > 0) before summing
"abs" -> take abs() before summing
"positive" -> take (x > 0) before summing
"rms" -> square before summing, we'll take sqrt later
"value -> just sum x itself
Returns (stats, count)
where stats is a Tensor of shape (x.shape[dim],), and the count
@ -56,9 +54,11 @@ def get_tensor_stats(x: Tensor, dim: int,
x = x.reshape(-1, x.shape[-1])
# shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
return torch.matmul(x.transpose(0, 1), x), count
elif stats_type == "mean-abs" or stats_type == "abs-value":
elif stats_type == "abs":
x = x.abs()
elif stats_type == "pos-ratio":
elif stats_type == "rms":
x = x ** 2
elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float)
else:
assert stats_type == "value"
@ -79,9 +79,9 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object
sizes_same: true if all the tensor sizes are the same on this dimension
stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value,
stats_type: either "abs" or "positive" or "eigs" or "value,
imdictates the type of stats
we accumulate, mean-abs is mean absolute value, "pos-ratio"
we accumulate, abs is mean absolute value, "positive"
is proportion of positive to nonnegative values, "eigs"
is eigenvalues after doing outer product on this dim, sum
over all other dimes.
@ -92,13 +92,11 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
mismatch and stats_type == "eigs"
"""
# stats_and_counts is a list of pair (Tensor, int)
if tensors[0].shape[dim] > 512 and stats_type == 'eigs':
return '' # won't produce eigs stats if dim too large.
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ]
if stats_type == 'eigs':
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except:
@ -114,6 +112,9 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
else:
stats = [ x[0] / x[1] for x in stats_and_counts ]
stats = torch.cat(stats, dim=0)
if stats_type == 'rms':
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
@ -140,11 +141,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions):
ndim = tensors[0].ndim
# options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the
# normal case.
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
stats_types = stats_types + ["eigs"]
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = [ "value", "abs" ]
for stats_type in stats_types:
sizes = [ x.shape[dim] for x in tensors ]
@ -158,7 +160,7 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "mean-abs" or "pos-ratio".
# stats_type will be "abs" or "positive".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
@ -223,7 +225,7 @@ class TensorDiagnostic(object):
class ModelDiagnostic(object):
def __init__(self, opts: TensorDiagnosticOptions):
def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
self.diagnostics = dict()
self.opts = opts
@ -286,7 +288,7 @@ def attach_diagnostics(model: nn.Module,
def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2**20, True)
opts = TensorDiagnosticOptions(2**20, 512)
diagnostic = TensorDiagnostic(opts, "foo")