Make diagnostics.py more error-tolerant and support a wider range of torch versions (#1234)

commit 973dc1026d
parent 543b4cc1ca
Author: Daniel Povey
Date: 2023-10-19 22:54:00 +08:00 (committed by GitHub)

@@ -244,12 +244,22 @@ class TensorDiagnostic(object):
             if stats_type == "eigs":
                 try:
-                    eigs, _ = torch.symeig(stats)
+                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                        eigs, _ = torch.linalg.eigh(stats)
+                    else:
+                        eigs, _ = torch.symeig(stats)
                     stats = eigs.abs().sqrt()
                 except:  # noqa
-                    print("Error getting eigenvalues, trying another method.")
-                    eigs, _ = torch.eig(stats)
-                    stats = eigs.norm(dim=1).sqrt()
+                    print(
+                        "Error getting eigenvalues, trying another method."
+                    )
+                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                        eigs, _ = torch.linalg.eig(stats)
+                        eigs = eigs.abs()
+                    else:
+                        eigs, _ = torch.eig(stats)
+                        eigs = eigs.norm(dim=1)
+                    stats = eigs.sqrt()
                 # sqrt so it reflects data magnitude, like stddev- not variance
             if stats_type in ["rms", "stddev"]:
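
The hunk above uses a feature-detection pattern: probe for the modern torch.linalg API with hasattr() before falling back to torch.symeig / torch.eig, which newer PyTorch releases deprecated and then removed. A minimal standalone sketch of the same dispatch, assuming torch is available (the helper name eig_magnitudes is illustrative, not part of the commit):

import torch

def eig_magnitudes(mat: torch.Tensor) -> torch.Tensor:
    """Return |eigenvalue| for each eigenvalue of a square matrix,
    dispatching on whichever eigendecomposition API this torch
    build provides. Illustrative helper, not from diagnostics.py."""
    try:
        # Symmetric path: torch.linalg.eigh where available,
        # otherwise the older torch.symeig.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            eigs, _ = torch.linalg.eigh(mat)
        else:
            eigs, _ = torch.symeig(mat)
        return eigs.abs()
    except Exception:  # noqa
        # General fallback for matrices the symmetric path rejects.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
            eigs, _ = torch.linalg.eig(mat)  # complex eigenvalues
            return eigs.abs()
        else:
            eigs, _ = torch.eig(mat)  # (n, 2) real/imag pairs
            return eigs.norm(dim=1)

With such a helper, stats = eig_magnitudes(stats).sqrt() would reproduce the computation in the diff: the square root is taken so the statistic reflects data magnitude (like a stddev), not variance.
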
@@ -569,11 +579,10 @@ def attach_diagnostics(
                 )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
+                    if isinstance(o, Tensor) and o.dtype in (torch.float32, torch.float16, torch.float64):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module))

         def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
@@ -587,11 +596,9 @@ def attach_diagnostics(
                 )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
+                    if isinstance(o, Tensor) and o.dtype in (torch.float32, torch.float16, torch.float64):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module))

         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
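
Both hunks add the same error-tolerance guard: check isinstance(o, Tensor) before touching o.dtype, since a module's tuple output may contain non-Tensor entries (None, ints such as sequence lengths, etc.) that would raise AttributeError. A self-contained sketch of that guard in a simplified forward hook, assuming torch; attach_output_hook and the stats dict are illustrative stand-ins, not the accumulate() machinery from diagnostics.py:

import torch
from torch import Tensor, nn

_FLOAT_DTYPES = (torch.float32, torch.float16, torch.float64)

def attach_output_hook(module: nn.Module, name: str, stats: dict) -> None:
    """Record the mean absolute value of every floating-point Tensor
    the module outputs, skipping non-Tensor tuple elements."""
    def forward_hook(_module, _input, _output):
        outputs = _output if isinstance(_output, tuple) else (_output,)
        for i, o in enumerate(outputs):
            # The isinstance() check mirrors the fix above: tuple
            # elements without a .dtype attribute are ignored.
            if isinstance(o, Tensor) and o.dtype in _FLOAT_DTYPES:
                stats.setdefault(f"{name}.output[{i}]", []).append(
                    o.detach().float().abs().mean().item()
                )
    module.register_forward_hook(forward_hook)

Usage would follow the same shape as attach_diagnostics: loop over model.named_modules() and call attach_output_hook(module, name, stats) for each submodule, so one hook per module records only the outputs that are actually floating-point Tensors.
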