Make diagnostics.py more error-tolerant and support a wider range of torch versions

This commit is contained in:
Daniel Povey 2023-08-31 12:42:53 +08:00
parent 4ab7d61008
commit 4a6035a4aa

View File

@ -247,14 +247,22 @@ class TensorDiagnostic(object):
if stats_type == "eigs": if stats_type == "eigs":
try: try:
eigs, _ = torch.symeig(stats) if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
eigs, _ = torch.linalg.eigh(stats)
else:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt() stats = eigs.abs().sqrt()
except: # noqa except: # noqa
print( print(
"Error getting eigenvalues, trying another method." "Error getting eigenvalues, trying another method."
) )
eigs, _ = torch.eig(stats) if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
stats = eigs.norm(dim=1).sqrt() eigs, _ = torch.linalg.eig(stats)
eigs = eigs.abs()
else:
eigs, _ = torch.eig(stats)
eigs = eigs.norm(dim=1)
stats = eigs.sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance # sqrt so it reflects data magnitude, like stddev- not variance
if stats_type in [ "rms", "stddev" ]: if stats_type in [ "rms", "stddev" ]:
@ -556,7 +564,7 @@ def attach_diagnostics(
class_name=get_class_name(_module)) class_name=get_class_name(_module))
elif isinstance(_output, tuple): elif isinstance(_output, tuple):
for i, o in enumerate(_output): for i, o in enumerate(_output):
if o.dtype in ( torch.float32, torch.float16, torch.float64 ): if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=get_class_name(_module)) class_name=get_class_name(_module))
@ -570,7 +578,7 @@ def attach_diagnostics(
class_name=get_class_name(_module)) class_name=get_class_name(_module))
elif isinstance(_output, tuple): elif isinstance(_output, tuple):
for i, o in enumerate(_output): for i, o in enumerate(_output):
if o.dtype in ( torch.float32, torch.float16, torch.float64 ): if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=get_class_name(_module)) class_name=get_class_name(_module))