Improvements to diagnostics (RE those with 1 dim

This commit is contained in:
Daniel Povey 2022-02-28 12:15:56 +08:00
parent 63d8d935d4
commit 2ff520c800

View File

@ -31,32 +31,34 @@ class TensorDiagnosticOptions(object):
def get_sum_abs_stats(x: Tensor, dim: int,
def get_tensor_stats(x: Tensor, dim: int,
stats_type: str) -> Tuple[Tensor, int]:
"""
Returns the sum-of-absolute-value of this Tensor, for each
index into the specified axis/dim of the tensor.
Returns the specified transformation of the Tensor (either x or x.abs()
or (x > 0), summed over all but the index `dim`.
Args:
x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim
stats_type: either "mean-abs" in which case the stats represent the
mean absolute value, or "pos-ratio" in which case the
stats represent the proportion of positive values (actually:
the tensor is count of positive values, count is the count of
all values).
Returns (sum_abs, count)
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
stats_type:
"mean-abs" or "abs-value" -> take abs() before summing
"pos-ratio" -> take (x > 0) before summing
"value -> just sum x itself
Returns (stats, count)
where stats is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element
of sum_abs.
of stats.
"""
if stats_type == "mean-abs":
if stats_type == "mean-abs" or stats_type == "abs-value":
x = x.abs()
else:
assert stats_type == "pos-ratio"
elif stats_type == "pos-ratio":
x = (x > 0).to(dtype=torch.float)
else:
assert stats_type == "value"
orig_numel = x.numel()
sum_dims = [ d for d in range(x.ndim) if d != dim ]
x = torch.sum(x, dim=sum_dims)
if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten()
return x, count
@ -79,7 +81,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
see the code.
"""
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ]
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ]
if sizes_same:
@ -114,9 +116,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions):
ndim = tensors[0].ndim
# options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the
# normal case.
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
for stats_type in options.stats_types():
# stats_type will be "mean-abs" or "pos-ratio".
for stats_type in stats_types:
sizes = [ x.shape[dim] for x in tensors ]
sizes_same = all([ x == sizes[0] for x in sizes ])
s = get_diagnostics_for_dim(dim, tensors,