Improvements to diagnostics (RE those with 1 dim

This commit is contained in:
Daniel Povey 2022-02-28 12:15:56 +08:00
parent 63d8d935d4
commit 2ff520c800

View File

@ -31,32 +31,34 @@ class TensorDiagnosticOptions(object):
def get_sum_abs_stats(x: Tensor, dim: int, def get_tensor_stats(x: Tensor, dim: int,
stats_type: str) -> Tuple[Tensor, int]: stats_type: str) -> Tuple[Tensor, int]:
""" """
Returns the sum-of-absolute-value of this Tensor, for each Returns the specified transformation of the Tensor (either x or x.abs()
index into the specified axis/dim of the tensor. or (x > 0), summed over all but the index `dim`.
Args: Args:
x: Tensor, tensor to be analyzed x: Tensor, tensor to be analyzed
dim: dimension with 0 <= dim < x.ndim dim: dimension with 0 <= dim < x.ndim
stats_type: either "mean-abs" in which case the stats represent the stats_type:
mean absolute value, or "pos-ratio" in which case the "mean-abs" or "abs-value" -> take abs() before summing
stats represent the proportion of positive values (actually: "pos-ratio" -> take (x > 0) before summing
the tensor is count of positive values, count is the count of "value -> just sum x itself
all values). Returns (stats, count)
Returns (sum_abs, count) where stats is a Tensor of shape (x.shape[dim],), and the count
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
is an integer saying how many items were counted in each element is an integer saying how many items were counted in each element
of sum_abs. of stats.
""" """
if stats_type == "mean-abs": if stats_type == "mean-abs" or stats_type == "abs-value":
x = x.abs() x = x.abs()
else: elif stats_type == "pos-ratio":
assert stats_type == "pos-ratio"
x = (x > 0).to(dtype=torch.float) x = (x > 0).to(dtype=torch.float)
else:
assert stats_type == "value"
orig_numel = x.numel() orig_numel = x.numel()
sum_dims = [ d for d in range(x.ndim) if d != dim ] sum_dims = [ d for d in range(x.ndim) if d != dim ]
x = torch.sum(x, dim=sum_dims) if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel() count = orig_numel // x.numel()
x = x.flatten() x = x.flatten()
return x, count return x, count
@ -79,7 +81,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
see the code. see the code.
""" """
# stats_and_counts is a list of pair (Tensor, int) # stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
stats = [ x[0] for x in stats_and_counts ] stats = [ x[0] for x in stats_and_counts ]
counts = [ x[1] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ]
if sizes_same: if sizes_same:
@ -114,9 +116,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
options: TensorDiagnosticOptions): options: TensorDiagnosticOptions):
ndim = tensors[0].ndim
# options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the
# normal case.
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
for stats_type in options.stats_types(): for stats_type in stats_types:
# stats_type will be "mean-abs" or "pos-ratio".
sizes = [ x.shape[dim] for x in tensors ] sizes = [ x.shape[dim] for x in tensors ]
sizes_same = all([ x == sizes[0] for x in sizes ]) sizes_same = all([ x == sizes[0] for x in sizes ])
s = get_diagnostics_for_dim(dim, tensors, s = get_diagnostics_for_dim(dim, tensors,