mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Make diagnostics.py more error-tolerant and have wider range of supported torch versions
This commit is contained in:
parent
4ab7d61008
commit
4a6035a4aa
@ -247,14 +247,22 @@ class TensorDiagnostic(object):
|
||||
|
||||
if stats_type == "eigs":
|
||||
try:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
|
||||
eigs, _ = torch.linalg.eigh(stats)
|
||||
else:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
except: # noqa
|
||||
print(
|
||||
"Error getting eigenvalues, trying another method."
|
||||
)
|
||||
eigs, _ = torch.eig(stats)
|
||||
stats = eigs.norm(dim=1).sqrt()
|
||||
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
|
||||
eigs, _ = torch.linalg.eig(stats)
|
||||
eigs = eigs.abs()
|
||||
else:
|
||||
eigs, _ = torch.eig(stats)
|
||||
eigs = eigs.norm(dim=1)
|
||||
stats = eigs.sqrt()
|
||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||
|
||||
if stats_type in [ "rms", "stddev" ]:
|
||||
@ -556,7 +564,7 @@ def attach_diagnostics(
|
||||
class_name=get_class_name(_module))
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
|
||||
class_name=get_class_name(_module))
|
||||
|
||||
@ -570,7 +578,7 @@ def attach_diagnostics(
|
||||
class_name=get_class_name(_module))
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
|
||||
class_name=get_class_name(_module))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user