Remove comparison diagnostics, which were not that useful.

This commit is contained in:
Daniel Povey 2022-10-22 13:57:00 +08:00
parent 2e93e5d3b7
commit 8d1021d131

View File

@ -287,122 +287,6 @@ class TensorDiagnostic(object):
)
def print_joint_diagnostics(self, other: 'TensorDiagnostic'):
"""
Prints diagnostics that relate to correlations between the 'basic' diagnostics
printed in print_diagnostics().
"""
combined_name = _summarize_two_names(self.name, other.name)
# e.g. combined_name == 'foo.{param_value,param_grad}' or just 'foo.param_value' if self.name == other.name.
if self.stats is None:
return
for dim, this_dim_stats in enumerate(self.stats):
try:
other_dim_stats = other.stats[dim]
except (TypeError, IndexError):
print(f"Continuing, dim={dim}, (0)")
continue
output_list = []
for stats_type, stats_list in this_dim_stats.items():
# stats_type could be "rms", "value", "abs", "eigs", "positive".
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
# shape of the stats), or None
if stats_list is None:
continue
# work out `size_str`, will be used to print out data later.. this is the
# same for all `stats_type` values
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = (
f"{sizes[0]}"
if len(sizes) == 1
else f"{min(sizes)}..{max(sizes)}"
)
if len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count
else:
stats = torch.cat(
[x.tensor / x.count for x in stats_list], dim=0
)
other_stats_list = other_dim_stats[stats_type]
for other_stats_type, other_stats_list in other_dim_stats.items():
# avoid redundantly comparing a,b and b,a
if (other_stats_type > stats_type or other_stats_list is None or
len(other_stats_list) == 0 or stats_list is other_stats_list):
continue
if len(stats_list) == 1:
other_stats = other_stats_list[0].tensor / other_stats_list[0].count
size = stats.shape[0]
else:
other_stats = torch.cat(
[x.tensor / x.count for x in other_stats_list], dim=0
)
if other_stats.shape != stats.shape:
# e.g. stats_type == "eigs" and other_stats_type !=
# "eigs" or the other way around
continue
if stats.ndim == 2:
# Matrices, for purposes of measuring eigenvalues. Just compute a dot-product-related
# measure of correlation of eigenvalues.
def norm_diag_mean(x):
return x - torch.eye(x.shape[0], dtype=x.dtype, device=x.device) * x.diag().mean()
stats = norm_diag_mean(stats)
other_stats = norm_diag_mean(other_stats)
correlation = ((stats * other_stats).sum() /
((stats**2).sum() * (other_stats**2).sum() + 1.0e-20).sqrt())
else:
# ndim == 1
# Use a rank-based measure of correlation
(_, indices1) = stats.sort()
(_, indices2) = other_stats.sort()
n = stats.numel()
rank1 = ((indices1 + 0.5) / n) - 0.5
rank2 = ((indices2 + 0.5) / n) - 0.5
correlation = (rank1 * rank2).sum() / (rank1 * rank1).sum()
output_list.append(f'{stats_type},{other_stats_type}={correlation:.3f}')
if len(output_list) == 0:
continue
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
output = f"module={combined_name}{maybe_class_name} dim={dim} size={size_str}: " + ' '.join(output_list)
print(output)
def _summarize_two_names(a: str, b:str, separator: str = ',') -> str:
"""
Given 'foo.ab' and 'foo.xyz', returns 'foo.{ab,xyz}'. If a == b,
returns a.
"""
if a == b:
return a
num_common_chars = min(len(a), len(b))
for i in range(num_common_chars):
if a[i] != b[i]:
num_common_chars = i
break
return '%s{%s%s%s}' % (a[:num_common_chars], a[num_common_chars:],
separator, b[num_common_chars:])
def _get_comparison_keys(k: str) -> List[str]:
"""
Gets names of diagnostic objects to compare with this one (including itself).
If k is "something.output" or "something.grad", will return ["something.output", "something.grad"]
If k is "something.param_value" or "something.param_grad", will return
"""
ending_sets = [ ['.output', '.grad'], ['.output[0]', '.grad[0]'], ['.output[1]', '.grad[1]'],
['.output[2]', '.grad[2]'], ['.param_value', '.param_grad'] ]
for s in ending_sets:
for end in s:
if k.endswith(end):
prefix = k[:-len(end)]
return [ prefix + suffix for suffix in s]
return [k]
class ModelDiagnostic(object):
"""This class stores diagnostics for all tensors in the torch.nn.Module.
@ -430,14 +314,6 @@ class ModelDiagnostic(object):
"""Print diagnostics for each tensor."""
for k in sorted(self.diagnostics.keys()):
self.diagnostics[k].print_diagnostics()
for l in _get_comparison_keys(k):
if l >= k: # this ensures we don't print redundant correlations
# for (a,b) and (b,a), since they are symmetric.
try:
self.diagnostics[k].print_joint_diagnostics(
self.diagnostics[l])
except KeyError:
pass
def attach_diagnostics(