Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-02 21:54:18 +00:00
update diagnostics.py
This commit is contained in:
parent a7643301ec
commit fb5d677c7f
@@ -1,5 +1,6 @@
 # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
-#                                       Zengwei Yao)
+#                                       Zengwei Yao
+#                                       Mingshuang Luo)
 #
 # See ../LICENSE for clarification regarding multiple authors
 #
@@ -28,51 +29,67 @@ class TensorDiagnosticOptions(object):
 
     Args:
       memory_limit:
-        The maximum number of bytes per tensor (limits how many copies
-        of the tensor we cache).
+        The maximum number of bytes per tensor
+        (limits how many copies of the tensor we cache).
+      max_eig_dim:
+        The maximum dimension for which we print out eigenvalues
+        (limited for speed reasons).
     """
 
-    def __init__(self, memory_limit: int):
+    def __init__(
+        self,
+        memory_limit: int = (2 ** 20),
+        max_eig_dim: int = 512
+    ):
         self.memory_limit = memory_limit
+        self.max_eig_dim = max_eig_dim
 
     def dim_is_summarized(self, size: int):
         return size > 10 and size != 31
 
 
-def get_sum_abs_stats(
+def get_tensor_stats(
     x: Tensor, dim: int, stats_type: str
 ) -> Tuple[Tensor, int]:
-    """Returns the sum-of-absolute-value of this Tensor, for each index into
-    the specified axis/dim of the tensor.
+    """
+    Returns the specified transformation of the Tensor (either x or x.abs()
+    or (x > 0)), summed over all but the index `dim`.
 
     Args:
-      x:
-        Tensor, tensor to be analyzed
-      dim:
-        Dimension with 0 <= dim < x.ndim
+      x: Tensor, tensor to be analyzed
+      dim: dimension with 0 <= dim < x.ndim
       stats_type:
-        Either "mean-abs" in which case the stats represent the mean absolute
-        value, or "pos-ratio" in which case the stats represent the proportion
-        of positive values (actually: the tensor is count of positive values,
-        count is the count of all values).
-
-    Returns:
-      (sum_abs, count) where sum_abs is a Tensor of shape (x.shape[dim],),
-      and the count is an integer saying how many items were counted in
-      each element of sum_abs.
+        "abs" -> take abs() before summing
+        "positive" -> take (x > 0) before summing
+        "rms" -> square before summing, we'll take sqrt later
+        "value" -> just sum x itself
+    Returns (stats, count)
+      where stats is a Tensor of shape (x.shape[dim],), and the count
+      is an integer saying how many items were counted in each element
+      of stats.
     """
-    if stats_type == "mean-abs":
+    count = x.numel() // x.shape[dim]
+
+    if stats_type == "eigs":
+        x = x.transpose(dim, -1)
+        x = x.reshape(-1, x.shape[-1])
+        # shape of returned tensor: (s, s),
+        # where s is size of dimension `dim` of original x.
+        return torch.matmul(x.transpose(0, 1), x), count
+    elif stats_type == "abs":
         x = x.abs()
-    else:
-        assert stats_type == "pos-ratio"
+    elif stats_type == "rms":
+        x = x ** 2
+    elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
+    else:
+        assert stats_type == "value"
 
-    orig_numel = x.numel()
-    sum_dims = [d for d in range(x.ndim) if d != dim]
-    x = torch.sum(x, dim=sum_dims)
-    count = orig_numel // x.numel()
+    sum_dims = [ d for d in range(x.ndim) if d != dim ]
+    if len(sum_dims) > 0:
+        x = torch.sum(x, dim=sum_dims)
     x = x.flatten()
 
     return x, count
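For orientation, here is a minimal self-contained sketch, not part of the commit, of the (stats, count) contract that the new get_tensor_stats() establishes; it reproduces the "abs", "positive" and "rms" reductions with plain torch ops on a hypothetical tensor.

# Minimal sketch, not from the commit: reproduce the (stats, count)
# contract of get_tensor_stats with plain torch ops on a toy tensor.
import torch

x = torch.randn(4, 3, 5)   # hypothetical activations
dim = 1                    # analyze the size-3 dimension

count = x.numel() // x.shape[dim]                 # 4 * 5 = 20 items per index
sum_dims = [d for d in range(x.ndim) if d != dim]

abs_stats = x.abs().sum(dim=sum_dims)             # shape (3,): sum of |x|
pos_stats = (x > 0).float().sum(dim=sum_dims)     # shape (3,): count of positives
rms_stats = (x ** 2).sum(dim=sum_dims)            # shape (3,): sum of squares

print(abs_stats / count)            # mean absolute value per index ("abs")
print(pos_stats / count)            # proportion of positive values ("positive")
print((rms_stats / count).sqrt())   # root-mean-square per index ("rms")

Dividing by count turns the raw sums into the per-index means that get_diagnostics_for_dim() later prints.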
@@ -83,43 +100,55 @@ def get_diagnostics_for_dim(
     sizes_same: bool,
     stats_type: str,
 ) -> str:
-    """This function gets diagnostics for a dimension of a module.
-
+    """
+    This function gets diagnostics for a dimension of a module.
     Args:
-      dim:
-        The dimension to analyze, with 0 <= dim < tensors[0].ndim
-      tensors:
-        List of cached tensors to get the stats
-      options:
-        Options object
-      sizes_same:
-        True if all the tensor sizes are the same on this dimension
-      stats_type: either "mean-abs" or "pos-ratio", dictates the type of
-        stats we accumulate, mean-abs is mean absolute value, "pos-ratio" is
-        proportion of positive to nonnegative values.
-
+      dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
+      options: options object
+      sizes_same: true if all the tensor sizes are the same on this dimension
+      stats_type: either "abs" or "positive" or "eigs" or "value";
+        dictates the type of stats we accumulate.  "abs" is mean absolute
+        value, "positive" is proportion of positive to nonnegative values,
+        "eigs" is eigenvalues after doing outer product on this dim, summed
+        over all other dims.
     Returns:
-      Diagnostic as a string, either percentiles or the actual values,
-      see the code.
+      Diagnostic as a string, either percentiles or the actual values,
+      see the code.  Will return the empty string if the diagnostics did
+      not make sense to print out for this dimension, e.g. dimension
+      mismatch and stats_type == "eigs".
     """
 
     # stats_and_counts is a list of pair (Tensor, int)
-    stats_and_counts = [get_sum_abs_stats(x, dim, stats_type) for x in tensors]
-    stats = [x[0] for x in stats_and_counts]
-    counts = [x[1] for x in stats_and_counts]
-    if sizes_same:
+    stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
+    stats = [ x[0] for x in stats_and_counts ]
+    counts = [ x[1] for x in stats_and_counts ]
+
+    if stats_type == "eigs":
+        try:
+            stats = torch.stack(stats).sum(dim=0)
+        except:
+            return ''
+        count = sum(counts)
+        stats = stats / count
+        stats, _ = torch.symeig(stats)
+        stats = stats.abs().sqrt()
+        # sqrt so it reflects data magnitude, like stddev, not variance
+    elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
         count = sum(counts)
         stats = stats / count
     else:
-        stats = [x[0] / x[1] for x in stats_and_counts]
+        stats = [ x[0] / x[1] for x in stats_and_counts ]
         stats = torch.cat(stats, dim=0)
+    if stats_type == 'rms':
+        stats = stats.sqrt()
 
-    # If `summarize` we print percentiles of the stats;
-    # else, we print out individual elements.
+    # if `summarize` we print percentiles of the stats; else,
+    # we print out individual elements.
     summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
     if summarize:
-        # Print out percentiles.
+        # print out percentiles.
         stats = stats.sort()[0]
         num_percentiles = 10
         size = stats.numel()
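As a rough illustration, not from the commit, of what the new "eigs" branch computes: the per-dim outer products are accumulated, normalized by the count, and the square roots of the absolute eigenvalues are reported so the numbers sit on a stddev-like scale. The diff calls torch.symeig, which newer PyTorch releases deprecate in favour of torch.linalg.eigvalsh; the sketch below uses the latter.

# Minimal sketch, not from the commit: the "eigs" statistic on toy data.
import torch

feats = torch.randn(1000, 64)            # hypothetical (items, channels) data
outer = torch.matmul(feats.t(), feats)   # (64, 64) sum of outer products
outer = outer / feats.shape[0]           # normalize by the count

# torch.symeig (used in the diff) is deprecated in recent PyTorch;
# torch.linalg.eigvalsh is the current routine for symmetric matrices.
eigs = torch.linalg.eigvalsh(outer)
stats = eigs.abs().sqrt()                # stddev-like scale, as in the diff
print(stats[-5:])                        # the few largest values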
@@ -127,14 +156,27 @@ def get_diagnostics_for_dim(
         for i in range(num_percentiles + 1):
             index = (i * (size - 1)) // num_percentiles
             percentiles.append(stats[index].item())
-        percentiles = ["%.2g" % x for x in percentiles]
-        percentiles = " ".join(percentiles)
-        return f"percentiles: [{percentiles}]"
+        percentiles = [ '%.2g' % x for x in percentiles ]
+        percentiles = ' '.join(percentiles)
+        ans = f'percentiles: [{percentiles}]'
     else:
-        stats = stats.tolist()
-        stats = ["%.2g" % x for x in stats]
-        stats = "[" + " ".join(stats) + "]"
-        return stats
+        ans = stats.tolist()
+        ans = [ '%.2g' % x for x in ans ]
+        ans = '[' + ' '.join(ans) + ']'
+    if stats_type == "value":
+        # This norm is useful because it is strictly less than the largest
+        # sqrt(eigenvalue) of the variance, which we print out, and shows,
+        # speaking in an approximate way, how much of that largest eigenvalue
+        # can be attributed to the mean of the distribution.
+        norm = (stats ** 2).sum().sqrt().item()
+        mean = stats.mean().item()
+        rms = (stats ** 2).mean().sqrt().item()
+        ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}'
+    else:
+        mean = stats.mean().item()
+        rms = (stats ** 2).mean().sqrt().item()
+        ans += f', mean={mean:.2g}, rms={rms:.2g}'
+    return ans
 
 
 def print_diagnostics_for_dim(
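A tiny worked sketch, not from the commit, of the summary that is appended to the returned string; for "value" stats the l2 norm of the per-index means is reported alongside their mean and rms, the figures the comment in the hunk above discusses.

# Minimal sketch, not from the commit: the trailing summary for "value" stats.
import torch

stats = torch.tensor([0.5, -0.2, 0.1, 0.8])   # hypothetical per-index means
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
print(f"norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}")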
@@ -153,17 +195,27 @@ def print_diagnostics_for_dim(
         Options object.
     """
 
-    for stats_type in ["mean-abs", "pos-ratio"]:
-        # stats_type will be "mean-abs" or "pos-ratio".
-        sizes = [x.shape[dim] for x in tensors]
-        sizes_same = all([x == sizes[0] for x in sizes])
-        s = get_diagnostics_for_dim(
-            dim, tensors, options, sizes_same, stats_type
-        )
+    ndim = tensors[0].ndim
+    if ndim > 1:
+        stats_types = ["abs", "positive", "value", "rms"]
+        if tensors[0].shape[dim] <= options.max_eig_dim:
+            stats_types.append("eigs")
+    else:
+        stats_types = [ "value", "abs" ]
+
+    for stats_type in stats_types:
+        sizes = [ x.shape[dim] for x in tensors ]
+        sizes_same = all([ x == sizes[0] for x in sizes ])
+        s = get_diagnostics_for_dim(dim, tensors,
+                                    options, sizes_same,
+                                    stats_type)
+        if s == '':
+            continue
 
         min_size = min(sizes)
         max_size = max(sizes)
         size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
+        # stats_type will be "abs" or "positive".
         print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
 
 
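A small self-contained sketch, not from the commit, of the gating added above: which stats types would be printed for a given cached tensor, with max_eig_dim defaulting to 512 as in the new options.

# Minimal sketch, not from the commit: which stats types get printed for a
# cached tensor, mirroring the gating added in print_diagnostics_for_dim.
import torch

def stats_types_for(t: torch.Tensor, dim: int, max_eig_dim: int = 512):
    if t.ndim > 1:
        types = ["abs", "positive", "value", "rms"]
        if t.shape[dim] <= max_eig_dim:
            types.append("eigs")
    else:
        types = ["value", "abs"]
    return types

print(stats_types_for(torch.randn(10, 256), dim=1))   # includes "eigs"
print(stats_types_for(torch.randn(10, 1024), dim=1))  # no "eigs": 1024 > 512
print(stats_types_for(torch.randn(7), dim=0))         # 1-D: ["value", "abs"]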
@@ -225,10 +277,17 @@ class TensorDiagnostic(object):
             # Ensure there is at least one dim.
             self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors]
 
+        try:
+            device = torch.device("cuda")
+            torch.ones(1, 1, device=device)
+        except:
+            device = torch.device("cpu")
+
         ndim = self.saved_tensors[0].ndim
+        tensors = [x.to(device) for x in self.saved_tensors]
         for dim in range(ndim):
             print_diagnostics_for_dim(
-                self.name, dim, self.saved_tensors, self.opts
+                self.name, dim, tensors, self.opts
             )
 
 
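The probe above moves the cached tensors to CUDA when a test allocation succeeds and otherwise falls back to CPU. A sketch of an equivalent, more explicit check, which the commit itself does not use:

# Minimal sketch, not from the commit: an explicit capability check with the
# same effect as the try/except probe above.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
saved_tensors = [torch.randn(3, 4) for _ in range(2)]   # stand-in for the cache
tensors = [t.to(device) for t in saved_tensors]
print(device, tensors[0].device)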
@@ -240,7 +299,7 @@ class ModelDiagnostic(object):
         Options object.
     """
 
-    def __init__(self, opts: TensorDiagnosticOptions):
+    def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
         # In this dictionary, the keys are tensors names and the values
         # are corresponding TensorDiagnostic objects.
         self.diagnostics = dict()
@@ -321,7 +380,7 @@ def attach_diagnostics(
 
 
 def _test_tensor_diagnostic():
-    opts = TensorDiagnosticOptions(2 ** 20)
+    opts = TensorDiagnosticOptions(2**20, 512)
 
     diagnostic = TensorDiagnostic(opts, "foo")
 
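Both constructor arguments now carry defaults (memory_limit=2**20, max_eig_dim=512), so the options object can be built with no arguments. A usage sketch, assuming the usual icefall.diagnostics import path:

# Minimal sketch, not from the commit; the import path is an assumption.
from icefall.diagnostics import TensorDiagnosticOptions

opts_default = TensorDiagnosticOptions()             # memory_limit=2**20, max_eig_dim=512
opts_custom = TensorDiagnosticOptions(2 ** 22, 256)  # bigger cache, smaller eig limit
print(opts_default.memory_limit, opts_default.max_eig_dim)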