mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove comparison diagnostics, which were not that useful.
This commit is contained in:
parent
2e93e5d3b7
commit
8d1021d131
@ -287,122 +287,6 @@ class TensorDiagnostic(object):
|
||||
)
|
||||
|
||||
|
||||
def print_joint_diagnostics(self, other: 'TensorDiagnostic'):
    """
    Prints diagnostics that relate to correlations between the 'basic' diagnostics
    printed in print_diagnostics().

    For each dim and each pair of stats types (e.g. "rms", "value", "abs",
    "eigs", "positive") whose averaged stats have the same shape, prints a
    correlation measure: a normalized dot product for 2-D (matrix) stats, or
    a rank-based measure for 1-D stats.

    Args:
       other: the TensorDiagnostic to compare against; may be `self`
          (comparing a stats list with itself is skipped below).
    """
    combined_name = _summarize_two_names(self.name, other.name)
    # e.g. combined_name == 'foo.{param_value,param_grad}' or just
    # 'foo.param_value' if self.name == other.name.
    if self.stats is None:
        return
    for dim, this_dim_stats in enumerate(self.stats):
        try:
            other_dim_stats = other.stats[dim]
        except (TypeError, IndexError):
            # other.stats may be None (TypeError) or have fewer dims than
            # self.stats (IndexError); nothing to compare for this dim.
            print(f"Continuing, dim={dim}, (0)")
            continue

        output_list = []
        for stats_type, stats_list in this_dim_stats.items():
            # stats_type could be "rms", "value", "abs", "eigs", "positive".
            # "stats_list" could be a list of TensorAndCount (one list per
            # distinct tensor shape of the stats), or None.
            if stats_list is None:
                continue
            # Work out `size_str`, will be used to print out data later;
            # this is the same for all `stats_type` values.
            sizes = [x.tensor.shape[0] for x in stats_list]
            size_str = (
                f"{sizes[0]}"
                if len(sizes) == 1
                else f"{min(sizes)}..{max(sizes)}"
            )

            # Average the accumulated tensor(s) by their counts,
            # concatenating if stats exist for multiple distinct shapes.
            if len(stats_list) == 1:
                stats = stats_list[0].tensor / stats_list[0].count
            else:
                stats = torch.cat(
                    [x.tensor / x.count for x in stats_list], dim=0
                )

            for other_stats_type, other_stats_list in other_dim_stats.items():
                # Avoid redundantly comparing (a,b) and (b,a): keep only
                # pairs with other_stats_type <= stats_type, and skip empty
                # lists and comparing a stats list with itself.
                if (other_stats_type > stats_type or other_stats_list is None or
                    len(other_stats_list) == 0 or stats_list is other_stats_list):
                    continue
                # NOTE(review): the branch below keys off len(stats_list),
                # not len(other_stats_list); if the two lists are segmented
                # differently, the shape check further down rejects the pair.
                if len(stats_list) == 1:
                    other_stats = other_stats_list[0].tensor / other_stats_list[0].count
                else:
                    other_stats = torch.cat(
                        [x.tensor / x.count for x in other_stats_list], dim=0
                    )
                if other_stats.shape != stats.shape:
                    # e.g. stats_type == "eigs" and other_stats_type !=
                    # "eigs" or the other way around.
                    continue

                if stats.ndim == 2:
                    # Matrices, for purposes of measuring eigenvalues. Just
                    # compute a dot-product-related measure of correlation
                    # of eigenvalues.
                    def norm_diag_mean(x):
                        # Subtract the mean diagonal value along the diagonal.
                        return x - torch.eye(x.shape[0], dtype=x.dtype, device=x.device) * x.diag().mean()
                    stats = norm_diag_mean(stats)
                    other_stats = norm_diag_mean(other_stats)
                    correlation = ((stats * other_stats).sum() /
                                   ((stats**2).sum() * (other_stats**2).sum() + 1.0e-20).sqrt())
                else:
                    # ndim == 1: use a rank-based measure of correlation.
                    # NOTE(review): this uses the sort indices directly
                    # rather than true ranks (argsort-of-argsort), and the
                    # denominator is not symmetric in rank1/rank2 -- confirm
                    # before treating this as Spearman's rho.
                    (_, indices1) = stats.sort()
                    (_, indices2) = other_stats.sort()
                    n = stats.numel()
                    rank1 = ((indices1 + 0.5) / n) - 0.5
                    rank2 = ((indices2 + 0.5) / n) - 0.5
                    correlation = (rank1 * rank2).sum() / (rank1 * rank1).sum()
                output_list.append(f'{stats_type},{other_stats_type}={correlation:.3f}')
        if len(output_list) == 0:
            continue

        maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
        output = f"module={combined_name}{maybe_class_name} dim={dim} size={size_str}: " + ' '.join(output_list)
        print(output)
|
||||
|
||||
|
||||
|
||||
def _summarize_two_names(a: str, b:str, separator: str = ',') -> str:
|
||||
"""
|
||||
Given 'foo.ab' and 'foo.xyz', returns 'foo.{ab,xyz}'. If a == b,
|
||||
returns a.
|
||||
"""
|
||||
if a == b:
|
||||
return a
|
||||
num_common_chars = min(len(a), len(b))
|
||||
for i in range(num_common_chars):
|
||||
if a[i] != b[i]:
|
||||
num_common_chars = i
|
||||
break
|
||||
return '%s{%s%s%s}' % (a[:num_common_chars], a[num_common_chars:],
|
||||
separator, b[num_common_chars:])
|
||||
|
||||
def _get_comparison_keys(k: str) -> List[str]:
|
||||
"""
|
||||
Gets names of diagnostic objects to compare with this one (including itself).
|
||||
If k is "something.output" or "something.grad", will return ["something.output", "something.grad"]
|
||||
If k is "something.param_value" or "something.param_grad", will return
|
||||
"""
|
||||
ending_sets = [ ['.output', '.grad'], ['.output[0]', '.grad[0]'], ['.output[1]', '.grad[1]'],
|
||||
['.output[2]', '.grad[2]'], ['.param_value', '.param_grad'] ]
|
||||
for s in ending_sets:
|
||||
for end in s:
|
||||
if k.endswith(end):
|
||||
prefix = k[:-len(end)]
|
||||
return [ prefix + suffix for suffix in s]
|
||||
return [k]
|
||||
|
||||
|
||||
class ModelDiagnostic(object):
|
||||
"""This class stores diagnostics for all tensors in the torch.nn.Module.
|
||||
@ -430,14 +314,6 @@ class ModelDiagnostic(object):
|
||||
"""Print diagnostics for each tensor."""
|
||||
for k in sorted(self.diagnostics.keys()):
|
||||
self.diagnostics[k].print_diagnostics()
|
||||
for l in _get_comparison_keys(k):
|
||||
if l >= k: # this ensures we don't print redundant correlations
|
||||
# for (a,b) and (b,a), since they are symmetric.
|
||||
try:
|
||||
self.diagnostics[k].print_joint_diagnostics(
|
||||
self.diagnostics[l])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def attach_diagnostics(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user