mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update diagnostics stats
This commit is contained in:
parent
fe595f8772
commit
b9696878b4
@ -19,7 +19,7 @@
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@ -188,14 +188,22 @@ class TensorDiagnostic(object):
|
||||
for dim, this_dim_stats in enumerate(self.stats):
|
||||
for stats_type, stats_list in this_dim_stats.items():
|
||||
# stats_type could be "rms", "value", "abs", "eigs", "positive".
|
||||
# "value" could be a list of TensorAndCount, or None
|
||||
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
|
||||
# shape of the stats), or None
|
||||
if stats_list is None:
|
||||
assert stats_type == "eigs"
|
||||
continue
|
||||
|
||||
if stats_type == "eigs":
|
||||
assert len(stats_list) == 1
|
||||
if len(stats_list) == 1:
|
||||
stats = stats_list[0].tensor / stats_list[0].count
|
||||
else:
|
||||
# a dimension that has variable size in different nnet
|
||||
# forwards, e.g. a time dimension in an ASR model.
|
||||
stats = torch.cat(
|
||||
[x.tensor / x.count for x in stats_list], dim=0
|
||||
)
|
||||
|
||||
if stats_type == "eigs":
|
||||
try:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
@ -206,12 +214,6 @@ class TensorDiagnostic(object):
|
||||
eigs = torch.linalg.eigvals(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||
elif len(stats_list) == 1:
|
||||
stats = stats_list[0].tensor / stats_list[0].count
|
||||
else:
|
||||
stats = torch.cat(
|
||||
[x.tensor / x.count for x in stats_list], dim=0
|
||||
)
|
||||
|
||||
if stats_type == "rms":
|
||||
# we stored the square; after aggregation we need to take sqrt.
|
||||
@ -264,6 +266,117 @@ class TensorDiagnostic(object):
|
||||
)
|
||||
|
||||
|
||||
def print_joint_diagnostics(self, other: 'TensorDiagnostic'):
    """
    Prints diagnostics that relate to correlations between the 'basic'
    diagnostics printed in print_diagnostics().

    For each dim, every statistic of `self` is paired with every statistic of
    `other` whose aggregated tensor has the same shape, and one line per dim
    is printed listing `<stats_type>,<other_stats_type>=<correlation>`.

    Args:
      other: the TensorDiagnostic to compare against; may be `self`, in
        which case a statistic is never compared with itself.
    """
    combined_name = _summarize_two_names(self.name, other.name)
    # e.g. combined_name == 'foo.{param_value,param_grad}' or just
    # 'foo.param_value' if self.name == other.name.

    def _aggregate(stats_list):
        # Normalize each accumulated tensor by its count; if there are
        # several entries (a dimension that has variable size in different
        # nnet forwards, e.g. a time dimension in an ASR model), concatenate.
        if len(stats_list) == 1:
            return stats_list[0].tensor / stats_list[0].count
        return torch.cat([x.tensor / x.count for x in stats_list], dim=0)

    for dim, this_dim_stats in enumerate(self.stats):
        try:
            other_dim_stats = other.stats[dim]
        except (TypeError, IndexError):
            print(f"Continuing, dim={dim}, (0)")
            continue

        output_list = []
        for stats_type, stats_list in this_dim_stats.items():
            # stats_type could be "rms", "value", "abs", "eigs", "positive".
            # "stats_list" could be a list of TensorAndCount (one entry per
            # distinct tensor shape of the stats), or None.
            # An empty list is skipped as well: min()/max() below would
            # raise on it, and the `other` side already gets the same check.
            if stats_list is None or len(stats_list) == 0:
                continue

            # work out `size_str`, will be used to print out data later..
            # this is the same for all `stats_type` values
            sizes = [x.tensor.shape[0] for x in stats_list]
            size_str = (
                f"{sizes[0]}"
                if len(sizes) == 1
                else f"{min(sizes)}..{max(sizes)}"
            )

            stats = _aggregate(stats_list)

            for other_stats_type, other_stats_list in other_dim_stats.items():
                # avoid redundantly comparing a,b and b,a
                if (other_stats_type > stats_type
                        or other_stats_list is None
                        or len(other_stats_list) == 0
                        or stats_list is other_stats_list):
                    continue

                # Aggregation must be driven by the length of
                # `other_stats_list` itself, not by that of `stats_list`.
                other_stats = _aggregate(other_stats_list)

                if other_stats.shape != stats.shape:
                    # e.g. stats_type == "eigs" and other_stats_type !=
                    # "eigs" or the other way around
                    continue

                if stats.ndim == 2:
                    # Matrices, for purposes of measuring eigenvalues.  Just
                    # compute a dot-product-related measure of correlation.
                    correlation = (
                        (stats * other_stats).sum()
                        / ((stats ** 2).sum() * (other_stats ** 2).sum()
                           + 1.0e-20).sqrt()
                    )
                else:
                    # ndim == 1
                    # Use a rank-based measure of correlation.  argsort()
                    # gives the sorting permutation; applying argsort() again
                    # inverts it, giving the rank of each element.
                    n = stats.numel()
                    rank1 = ((stats.argsort().argsort() + 0.5) / n) - 0.5
                    rank2 = ((other_stats.argsort().argsort() + 0.5) / n) - 0.5
                    # The epsilon keeps a size-1 dimension from producing
                    # 0/0 (its only centered rank is 0).
                    correlation = (
                        (rank1 * rank2).sum()
                        / ((rank1 * rank1).sum() + 1.0e-20)
                    )
                output_list.append(
                    f'{stats_type},{other_stats_type}={correlation:.3f}'
                )

        if len(output_list) == 0:
            continue

        maybe_class_name = (
            f" type={self.class_name}," if self.class_name is not None else ""
        )
        output = (
            f"module={combined_name}{maybe_class_name} dim={dim} "
            f"size={size_str}: " + ' '.join(output_list)
        )
        print(output)
|
||||
|
||||
|
||||
|
||||
def _summarize_two_names(a: str, b:str, separator: str = ',') -> str:
|
||||
"""
|
||||
Given 'foo.ab' and 'foo.xyz', returns 'foo.{ab,xyz}'. If a == b,
|
||||
returns a.
|
||||
"""
|
||||
if a == b:
|
||||
return a
|
||||
num_common_chars = min(len(a), len(b))
|
||||
for i in range(num_common_chars):
|
||||
if a[i] != b[i]:
|
||||
num_common_chars = i
|
||||
break
|
||||
return '%s{%s%s%s}' % (a[:num_common_chars], a[num_common_chars:],
|
||||
separator, b[num_common_chars:])
|
||||
|
||||
def _get_comparison_keys(k: str) -> List[str]:
|
||||
"""
|
||||
Gets names of diagnostic objects to compare with this one (including itself).
|
||||
If k is "something.output" or "something.grad", will return ["something.output", "something.grad"]
|
||||
If k is "something.param_value" or "something.param_grad", will return
|
||||
"""
|
||||
ending_sets = [ ['.output', '.grad'], ['.output[0]', '.grad[0]'], ['.output[1]', '.grad[1]'],
|
||||
['.output[2]', '.grad[2]'], ['.param_value', '.param_grad'] ]
|
||||
for s in ending_sets:
|
||||
for end in s:
|
||||
if k.endswith(end):
|
||||
prefix = k[:-len(end)]
|
||||
return [ prefix + suffix for suffix in s]
|
||||
return [k]
|
||||
|
||||
|
||||
class ModelDiagnostic(object):
|
||||
"""This class stores diagnostics for all tensors in the torch.nn.Module.
|
||||
|
||||
@ -290,6 +403,14 @@ class ModelDiagnostic(object):
|
||||
"""Print diagnostics for each tensor."""
|
||||
for k in sorted(self.diagnostics.keys()):
|
||||
self.diagnostics[k].print_diagnostics()
|
||||
for l in _get_comparison_keys(k):
|
||||
if l >= k: # this ensures we don't print redundant correlations
|
||||
# for (a,b) and (b,a), since they are symmetric.
|
||||
try:
|
||||
self.diagnostics[k].print_joint_diagnostics(
|
||||
self.diagnostics[l])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def attach_diagnostics(
|
||||
@ -324,6 +445,8 @@ def attach_diagnostics(
|
||||
def forward_hook(
|
||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
if isinstance(_output, tuple) and len(_output) == 1:
|
||||
_output = _output[0]
|
||||
|
||||
if isinstance(_output, Tensor):
|
||||
_model_diagnostic[f"{_name}.output"].accumulate(_output,
|
||||
@ -336,6 +459,8 @@ def attach_diagnostics(
|
||||
def backward_hook(
|
||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
if isinstance(_output, tuple) and len(_output) == 1:
|
||||
_output = _output[0]
|
||||
if isinstance(_output, Tensor):
|
||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
|
||||
class_name=type(_module).__name__)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user