Merge changes to diagnostics

commit feb20ca84d
@@ -11,24 +11,21 @@ class TensorDiagnosticOptions(object):
     Options object for tensor diagnostics:
 
     Args:
-      memory_limit: the maximum number of bytes per tensor (limits how many copies
+      memory_limit: the maximum number of bytes we store per tensor (limits how many copies
         of the tensor we cache).
+      max_eig_dim: the maximum dimension for which we print out eigenvalues
+        (limited for speed reasons).
     """
-    def __init__(self, memory_limit: int,
-                 print_pos_ratio: bool = True):
+    def __init__(self,
+                 memory_limit: int = (2 ** 20),
+                 max_eig_dim: int = 512):
         self.memory_limit = memory_limit
-        self.print_pos_ratio = print_pos_ratio
+        self.max_eig_dim = max_eig_dim
 
     def dim_is_summarized(self, size: int):
         return size > 10 and size != 31
 
-    def stats_types(self):
-        if self.print_pos_ratio:
-            return ["mean-abs", "pos-ratio"]
-        else:
-            return ["mean-abs"]
-
 
 def get_tensor_stats(x: Tensor, dim: int,
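For illustration only (not part of this commit), a minimal sketch of constructing the reworked options object, assuming the class exactly as shown in the hunk above:

    opts = TensorDiagnosticOptions()                        # defaults: memory_limit=2**20, max_eig_dim=512
    small = TensorDiagnosticOptions(memory_limit=2 ** 18,   # cache fewer bytes per tensor
                                    max_eig_dim=128)        # compute eig stats only for dims of size <= 128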
@@ -41,25 +38,34 @@ def get_tensor_stats(x: Tensor, dim: int,
       x: Tensor, tensor to be analyzed
       dim: dimension with 0 <= dim < x.ndim
       stats_type:
-        "mean-abs" or "abs-value" -> take abs() before summing
-        "pos-ratio" -> take (x > 0) before summing
+        "abs" -> take abs() before summing
+        "positive" -> take (x > 0) before summing
+        "rms" -> square before summing, we'll take sqrt later
         "value" -> just sum x itself
     Returns (stats, count)
      where stats is a Tensor of shape (x.shape[dim],), and the count
      is an integer saying how many items were counted in each element
      of stats.
     """
-    if stats_type == "mean-abs" or stats_type == "abs-value":
+    count = x.numel() // x.shape[dim]
+
+    if stats_type == "eigs":
+        x = x.transpose(dim, -1)
+        x = x.reshape(-1, x.shape[-1])
+        # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x.
+        return torch.matmul(x.transpose(0, 1), x), count
+    elif stats_type == "abs":
         x = x.abs()
-    elif stats_type == "pos-ratio":
+    elif stats_type == "rms":
+        x = x ** 2
+    elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
         assert stats_type == "value"
-    orig_numel = x.numel()
     sum_dims = [ d for d in range(x.ndim) if d != dim ]
     if len(sum_dims) > 0:
         x = torch.sum(x, dim=sum_dims)
-    count = orig_numel // x.numel()
     x = x.flatten()
     return x, count
 
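An illustrative sketch (mine, not in the commit) of what the new "eigs" and "rms" stats types produce, following the code in the hunk above:

    import torch

    x = torch.randn(8, 4)                        # e.g. 8 frames, 4 channels; analyze dim=1
    count = x.numel() // x.shape[1]              # 8 items contribute to each stat

    # "eigs": accumulate an outer-product (Gram) matrix over all other dims.
    gram = torch.matmul(x.transpose(0, 1), x)    # shape (4, 4)

    # "rms": square before summing; the caller later divides by count and takes sqrt.
    rms = ((x ** 2).sum(dim=0) / count).sqrt()   # per-channel RMS, shape (4,)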
@@ -73,24 +79,42 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
       dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim
       options: options object
       sizes_same: true if all the tensor sizes are the same on this dimension
-      stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats
-          we accumulate, mean-abs is mean absolute value, "pos-ratio"
-          is proportion of positive to nonnegative values.
+      stats_type: either "abs" or "positive" or "eigs" or "value",
+          dictates the type of stats
+          we accumulate, "abs" is mean absolute value, "positive"
+          is proportion of positive to nonnegative values, "eigs"
+          is eigenvalues after doing outer product on this dim, sum
+          over all other dims.
     Returns:
       Diagnostic as a string, either percentiles or the actual values,
-      see the code.
+      see the code.  Will return the empty string if the diagnostics did
+      not make sense to print out for this dimension, e.g. dimension
+      mismatch and stats_type == "eigs".
     """
     # stats_and_counts is a list of pair (Tensor, int)
     stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ]
     stats = [ x[0] for x in stats_and_counts ]
     counts = [ x[1] for x in stats_and_counts ]
-    if sizes_same:
+
+    if stats_type == "eigs":
+        try:
+            stats = torch.stack(stats).sum(dim=0)
+        except:
+            return ''
+        count = sum(counts)
+        stats = stats / count
+        stats, _ = torch.symeig(stats)
+        stats = stats.abs().sqrt()  # sqrt so it reflects data magnitude, like stddev, not variance
+    elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
         count = sum(counts)
         stats = stats / count
     else:
         stats = [ x[0] / x[1] for x in stats_and_counts ]
         stats = torch.cat(stats, dim=0)
+    if stats_type == 'rms':
+        stats = stats.sqrt()
 
     # if `summarize` we print percentiles of the stats; else,
     # we print out individual elements.
     summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
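A sketch (not part of the commit) of the eigenvalue branch above on a toy batch. Note that torch.symeig, used in the new code, is deprecated in recent PyTorch releases; the illustration uses torch.linalg.eigvalsh, its modern equivalent:

    import torch

    tensors = [torch.randn(8, 4), torch.randn(16, 4)]
    grams = [torch.matmul(t.transpose(0, 1), t) for t in tensors]  # per-tensor (4, 4) Gram matrices
    counts = [t.numel() // t.shape[1] for t in tensors]            # 8 and 16

    stats = torch.stack(grams).sum(dim=0) / sum(counts)            # averaged outer product
    eigs = torch.linalg.eigvalsh(stats)                            # eigenvalues of the symmetric matrix
    eigs = eigs.abs().sqrt()                                       # on the scale of a stddev, not a variance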
@@ -117,9 +141,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor],
 def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
                               options: TensorDiagnosticOptions):
     ndim = tensors[0].ndim
-    # options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the
-    # normal case.
-    stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ]
+    if ndim > 1:
+        stats_types = ["abs", "positive", "value", "rms"]
+        if tensors[0].shape[dim] <= options.max_eig_dim:
+            stats_types.append("eigs")
+    else:
+        stats_types = [ "value", "abs" ]
 
     for stats_type in stats_types:
         sizes = [ x.shape[dim] for x in tensors ]
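For illustration (not from the commit), how the selection above plays out for a hypothetical 30 x 600 tensor with the default max_eig_dim of 512, assuming TensorDiagnosticOptions from this file is in scope:

    import torch

    opts = TensorDiagnosticOptions()          # max_eig_dim defaults to 512
    x = torch.randn(30, 600)
    for dim in range(x.ndim):
        stats_types = ["abs", "positive", "value", "rms"]
        if x.shape[dim] <= opts.max_eig_dim:
            stats_types.append("eigs")
        print(dim, stats_types)   # dim 0 gets "eigs" (30 <= 512); dim 1 does not (600 > 512)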
@@ -127,11 +154,13 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor],
         s = get_diagnostics_for_dim(dim, tensors,
                                     options, sizes_same,
                                     stats_type)
+        if s == '':
+            continue
+
         min_size = min(sizes)
         max_size = max(sizes)
         size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
-        # stats_type will be "mean-abs" or "pos-ratio".
+        # stats_type will be "abs" or "positive".
         print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
 
 
@@ -181,15 +210,22 @@ class TensorDiagnostic(object):
             # ensure there is at least one dim.
             self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ]
 
+        try:
+            device = torch.device('cuda')
+            torch.ones(1, 1, device=device)
+        except:
+            device = torch.device('cpu')
+
         ndim = self.saved_tensors[0].ndim
+        tensors = [x.to(device) for x in self.saved_tensors]
         for dim in range(ndim):
             print_diagnostics_for_dim(self.name, dim,
-                                      self.saved_tensors,
+                                      tensors,
                                       self.opts)
 
 
 class ModelDiagnostic(object):
-    def __init__(self, opts: TensorDiagnosticOptions):
+    def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()):
         self.diagnostics = dict()
         self.opts = opts
 
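Aside (my illustration, not in the commit): the try/except probe above amounts to the standard CUDA availability check; a self-contained sketch of the equivalent:

    import torch

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    saved_tensors = [torch.randn(4, 8), torch.randn(4, 8)]   # stand-in for self.saved_tensors
    tensors = [x.to(device) for x in saved_tensors]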
@@ -252,7 +288,7 @@ def attach_diagnostics(model: nn.Module,
 
 
 def _test_tensor_diagnostic():
-    opts = TensorDiagnosticOptions(2**20, True)
+    opts = TensorDiagnosticOptions(2**20, 512)
 
     diagnostic = TensorDiagnostic(opts, "foo")
 
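A hedged usage sketch based on the test above, assuming accumulate() and print_diagnostics() methods defined elsewhere in this file (not shown in these hunks):

    import torch

    opts = TensorDiagnosticOptions(2 ** 20, 512)
    diagnostic = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        diagnostic.accumulate(torch.randn(50, 100) * 10.0)  # assumed API, mirrors the test function
    diagnostic.print_diagnostics()                          # assumed API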