mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Improvements to diagnostics (RE those with 1 dim)
This commit is contained in:
parent
63d8d935d4
commit
2ff520c800
@ -31,31 +31,33 @@ class TensorDiagnosticOptions(object):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_sum_abs_stats(x: Tensor, dim: int,
|
def get_tensor_stats(x: Tensor, dim: int,
|
||||||
stats_type: str) -> Tuple[Tensor, int]:
|
stats_type: str) -> Tuple[Tensor, int]:
|
||||||
"""
|
"""
|
||||||
Returns the sum-of-absolute-value of this Tensor, for each
|
Returns the specified transformation of the Tensor (either x or x.abs()
|
||||||
index into the specified axis/dim of the tensor.
|
or (x > 0)), summed over all but the index `dim`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Tensor, tensor to be analyzed
|
x: Tensor, tensor to be analyzed
|
||||||
dim: dimension with 0 <= dim < x.ndim
|
dim: dimension with 0 <= dim < x.ndim
|
||||||
stats_type: either "mean-abs" in which case the stats represent the
|
stats_type:
|
||||||
mean absolute value, or "pos-ratio" in which case the
|
"mean-abs" or "abs-value" -> take abs() before summing
|
||||||
stats represent the proportion of positive values (actually:
|
"pos-ratio" -> take (x > 0) before summing
|
||||||
the tensor is count of positive values, count is the count of
|
"value" -> just sum x itself
|
||||||
all values).
|
Returns (stats, count)
|
||||||
Returns (sum_abs, count)
|
where stats is a Tensor of shape (x.shape[dim],), and the count
|
||||||
where sum_abs is a Tensor of shape (x.shape[dim],), and the count
|
|
||||||
is an integer saying how many items were counted in each element
|
is an integer saying how many items were counted in each element
|
||||||
of sum_abs.
|
of stats.
|
||||||
"""
|
"""
|
||||||
if stats_type == "mean-abs":
|
if stats_type == "mean-abs" or stats_type == "abs-value":
|
||||||
x = x.abs()
|
x = x.abs()
|
||||||
else:
|
elif stats_type == "pos-ratio":
|
||||||
assert stats_type == "pos-ratio"
|
|
||||||
x = (x > 0).to(dtype=torch.float)
|
x = (x > 0).to(dtype=torch.float)
|
||||||
|
else:
|
||||||
|
assert stats_type == "value"
|
||||||
orig_numel = x.numel()
|
orig_numel = x.numel()
|
||||||
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
||||||
|
if len(sum_dims) > 0:
|
||||||
x = torch.sum(x, dim=sum_dims)
|
x = torch.sum(x, dim=sum_dims)
|
||||||
count = orig_numel // x.numel()
|
count = orig_numel // x.numel()
|
||||||
x = x.flatten()
|
x = x.flatten()
|
||||||
@ -79,7 +81,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
|||||||
see the code.
|
see the code.
|
||||||
"""
|
"""
|
||||||
# stats_and_counts is a list of pair (Tensor, int)
|
# stats_and_counts is a list of pairs (Tensor, int)
|
||||||
stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ]
|
stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
|
||||||
stats = [ x[0] for x in stats_and_counts ]
|
stats = [ x[0] for x in stats_and_counts ]
|
||||||
counts = [ x[1] for x in stats_and_counts ]
|
counts = [ x[1] for x in stats_and_counts ]
|
||||||
if sizes_same:
|
if sizes_same:
|
||||||
@ -114,9 +116,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
|
|||||||
|
|
||||||
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
|
def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
|
||||||
options: TensorDiagnosticOptions):
|
options: TensorDiagnosticOptions):
|
||||||
|
ndim = tensors[0].ndim
|
||||||
|
# options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the
|
||||||
|
# normal case.
|
||||||
|
stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
|
||||||
|
|
||||||
for stats_type in options.stats_types():
|
for stats_type in stats_types:
|
||||||
# stats_type will be "mean-abs" or "pos-ratio".
|
|
||||||
sizes = [ x.shape[dim] for x in tensors ]
|
sizes = [ x.shape[dim] for x in tensors ]
|
||||||
sizes_same = all([ x == sizes[0] for x in sizes ])
|
sizes_same = all([ x == sizes[0] for x in sizes ])
|
||||||
s = get_diagnostics_for_dim(dim, tensors,
|
s = get_diagnostics_for_dim(dim, tensors,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user