mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Extensions to diagnostics code
This commit is contained in:
parent
2ff520c800
commit
d074cf73c6
@ -25,7 +25,7 @@ class TensorDiagnosticOptions(object):
|
||||
|
||||
def stats_types(self):
|
||||
if self.print_pos_ratio:
|
||||
return ["mean-abs", "pos-ratio"]
|
||||
return ["mean-abs", "pos-ratio", "value"]
|
||||
else:
|
||||
return ["mean-abs"]
|
||||
|
||||
@ -49,17 +49,23 @@ def get_tensor_stats(x: Tensor, dim: int,
|
||||
is an integer saying how many items were counted in each element
|
||||
of stats.
|
||||
"""
|
||||
if stats_type == "mean-abs" or stats_type == "abs-value":
|
||||
count = x.numel() // x.shape[dim]
|
||||
|
||||
if stats_type == "eigs":
|
||||
x = x.transpose(dim, -1)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
# shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
|
||||
return torch.matmul(x.transpose(0, 1), x), count
|
||||
elif stats_type == "mean-abs" or stats_type == "abs-value":
|
||||
x = x.abs()
|
||||
elif stats_type == "pos-ratio":
|
||||
x = (x > 0).to(dtype=torch.float)
|
||||
else:
|
||||
assert stats_type == "value"
|
||||
orig_numel = x.numel()
|
||||
|
||||
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
||||
if len(sum_dims) > 0:
|
||||
x = torch.sum(x, dim=sum_dims)
|
||||
count = orig_numel // x.numel()
|
||||
x = x.flatten()
|
||||
return x, count
|
||||
|
||||
@ -73,18 +79,35 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
||||
dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
|
||||
options: options object
|
||||
sizes_same: true if all the tensor sizes are the same on this dimension
|
||||
stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
|
||||
stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value,
|
||||
imdictates the type of stats
|
||||
we accumulate, mean-abs is mean absolute value, "pos-ratio"
|
||||
is proportion of positive to nonnegative values.
|
||||
is proportion of positive to nonnegative values, "eigs"
|
||||
is eigenvalues after doing outer product on this dim, sum
|
||||
over all other dimes.
|
||||
Returns:
|
||||
Diagnostic as a string, either percentiles or the actual values,
|
||||
see the code.
|
||||
see the code. Will return the empty string if the diagnostics did
|
||||
not make sense to print out for this dimension, e.g. dimension
|
||||
mismatch and stats_type == "eigs"
|
||||
"""
|
||||
# stats_and_counts is a list of pair (Tensor, int)
|
||||
if tensors[0].shape[dim] > 512 and stats_type == 'eigs':
|
||||
return '' # won't produce eigs stats if dim too large.
|
||||
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
|
||||
stats = [ x[0] for x in stats_and_counts ]
|
||||
counts = [ x[1] for x in stats_and_counts ]
|
||||
if sizes_same:
|
||||
|
||||
if stats_type == 'eigs':
|
||||
try:
|
||||
stats = torch.stack(stats).sum(dim=0)
|
||||
except:
|
||||
return ''
|
||||
count = sum(counts)
|
||||
stats = stats / count
|
||||
stats, _ = torch.symeig(stats)
|
||||
stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance
|
||||
elif sizes_same:
|
||||
stats = torch.stack(stats).sum(dim=0)
|
||||
count = sum(counts)
|
||||
stats = stats / count
|
||||
@ -121,12 +144,16 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
|
||||
# normal case.
|
||||
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
|
||||
|
||||
stats_types = stats_types + ["eigs"]
|
||||
|
||||
for stats_type in stats_types:
|
||||
sizes = [ x.shape[dim] for x in tensors ]
|
||||
sizes_same = all([ x == sizes[0] for x in sizes ])
|
||||
s = get_diagnostics_for_dim(dim, tensors,
|
||||
options, sizes_same,
|
||||
stats_type)
|
||||
if s == '':
|
||||
continue
|
||||
|
||||
min_size = min(sizes)
|
||||
max_size = max(sizes)
|
||||
@ -181,10 +208,17 @@ class TensorDiagnostic(object):
|
||||
# ensure there is at least one dim.
|
||||
self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
|
||||
|
||||
try:
|
||||
device = torch.device('cuda')
|
||||
torch.ones(1, 1, device)
|
||||
except:
|
||||
device = torch.device('cpu')
|
||||
|
||||
ndim = self.saved_tensors[0].ndim
|
||||
tensors = [x.to(device) for x in self.saved_tensors]
|
||||
for dim in range(ndim):
|
||||
print_diagnostics_for_dim(self.name, dim,
|
||||
self.saved_tensors,
|
||||
tensors,
|
||||
self.opts)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user