Extensions to diagnostics code

This commit is contained in:
Daniel Povey 2022-03-09 20:37:20 +08:00
parent 2ff520c800
commit d074cf73c6

View File

@ -25,7 +25,7 @@ class TensorDiagnosticOptions(object):
def stats_types(self):
if self.print_pos_ratio:
return ["mean-abs", "pos-ratio"]
return ["mean-abs", "pos-ratio", "value"]
else:
return ["mean-abs"]
@ -49,17 +49,23 @@ def get_tensor_stats(x: Tensor, dim: int,
is an integer saying how many items were counted in each element
of stats.
"""
if stats_type == "mean-abs" or stats_type == "abs-value":
count = x.numel() // x.shape[dim]
if stats_type == "eigs":
x = x.transpose(dim, -1)
x = x.reshape(-1, x.shape[-1])
# shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
return torch.matmul(x.transpose(0, 1), x), count
elif stats_type == "mean-abs" or stats_type == "abs-value":
x = x.abs()
elif stats_type == "pos-ratio":
x = (x > 0).to(dtype=torch.float)
else:
assert stats_type == "value"
orig_numel = x.numel()
sum_dims = [ d for d in range(x.ndim) if d != dim ]
if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten()
return x, count
@ -73,18 +79,35 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
options: options object
sizes_same: true if all the tensor sizes are the same on this dimension
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value",
dictates the type of stats
we accumulate, mean-abs is mean absolute value, "pos-ratio"
is proportion of positive to nonnegative values.
is proportion of positive to nonnegative values, "eigs"
is eigenvalues after doing outer product on this dim, sum
over all other dims.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code.
see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs"
"""
# stats_and_counts is a list of pair (Tensor, int)
if tensors[0].shape[dim] > 512 and stats_type == 'eigs':
return '' # won't produce eigs stats if dim too large.
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ]
if sizes_same:
if stats_type == 'eigs':
try:
stats = torch.stack(stats).sum(dim=0)
except:
return ''
count = sum(counts)
stats = stats / count
stats, _ = torch.symeig(stats)
stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
@ -121,12 +144,16 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
# normal case.
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
stats_types = stats_types + ["eigs"]
for stats_type in stats_types:
sizes = [ x.shape[dim] for x in tensors ]
sizes_same = all([ x == sizes[0] for x in sizes ])
s = get_diagnostics_for_dim(dim, tensors,
options, sizes_same,
stats_type)
if s == '':
continue
min_size = min(sizes)
max_size = max(sizes)
@ -181,10 +208,17 @@ class TensorDiagnostic(object):
# ensure there is at least one dim.
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
try:
device = torch.device('cuda')
torch.ones(1, 1, device)
except:
device = torch.device('cpu')
ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim,
self.saved_tensors,
tensors,
self.opts)