mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Make diagnostics.py more error-tolerant and have wider range of supported torch versions (#1234)
This commit is contained in:
parent
543b4cc1ca
commit
973dc1026d
@ -244,12 +244,22 @@ class TensorDiagnostic(object):
|
|||||||
|
|
||||||
if stats_type == "eigs":
|
if stats_type == "eigs":
|
||||||
try:
|
try:
|
||||||
|
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
|
||||||
|
eigs, _ = torch.linalg.eigh(stats)
|
||||||
|
else:
|
||||||
eigs, _ = torch.symeig(stats)
|
eigs, _ = torch.symeig(stats)
|
||||||
stats = eigs.abs().sqrt()
|
stats = eigs.abs().sqrt()
|
||||||
except: # noqa
|
except: # noqa
|
||||||
print("Error getting eigenvalues, trying another method.")
|
print(
|
||||||
|
"Error getting eigenvalues, trying another method."
|
||||||
|
)
|
||||||
|
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
|
||||||
|
eigs, _ = torch.linalg.eig(stats)
|
||||||
|
eigs = eigs.abs()
|
||||||
|
else:
|
||||||
eigs, _ = torch.eig(stats)
|
eigs, _ = torch.eig(stats)
|
||||||
stats = eigs.norm(dim=1).sqrt()
|
eigs = eigs.norm(dim=1)
|
||||||
|
stats = eigs.sqrt()
|
||||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||||
|
|
||||||
if stats_type in ["rms", "stddev"]:
|
if stats_type in ["rms", "stddev"]:
|
||||||
@ -569,10 +579,9 @@ def attach_diagnostics(
|
|||||||
)
|
)
|
||||||
elif isinstance(_output, tuple):
|
elif isinstance(_output, tuple):
|
||||||
for i, o in enumerate(_output):
|
for i, o in enumerate(_output):
|
||||||
if o.dtype in (torch.float32, torch.float16, torch.float64):
|
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(
|
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
|
||||||
o, class_name=get_class_name(_module)
|
class_name=get_class_name(_module))
|
||||||
)
|
|
||||||
|
|
||||||
def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
|
def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
|
||||||
if isinstance(_output, tuple) and len(_output) == 1:
|
if isinstance(_output, tuple) and len(_output) == 1:
|
||||||
@ -587,11 +596,9 @@ def attach_diagnostics(
|
|||||||
)
|
)
|
||||||
elif isinstance(_output, tuple):
|
elif isinstance(_output, tuple):
|
||||||
for i, o in enumerate(_output):
|
for i, o in enumerate(_output):
|
||||||
if o.dtype in (torch.float32, torch.float16, torch.float64):
|
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||||
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
|
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
|
||||||
o, class_name=get_class_name(_module)
|
class_name=get_class_name(_module))
|
||||||
)
|
|
||||||
|
|
||||||
module.register_forward_hook(forward_hook)
|
module.register_forward_hook(forward_hook)
|
||||||
module.register_backward_hook(backward_hook)
|
module.register_backward_hook(backward_hook)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user