From 973dc1026d93c5ce551428459077187a3cd1e0a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 Oct 2023 22:54:00 +0800 Subject: [PATCH] Make diagnostics.py more error-tolerant and have wider range of supported torch versions (#1234) --- icefall/diagnostics.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 700dc1500..ebf61784e 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -244,12 +244,22 @@ class TensorDiagnostic(object): if stats_type == "eigs": try: - eigs, _ = torch.symeig(stats) + if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'): + eigs, _ = torch.linalg.eigh(stats) + else: + eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print("Error getting eigenvalues, trying another method.") - eigs, _ = torch.eig(stats) - stats = eigs.norm(dim=1).sqrt() + print( + "Error getting eigenvalues, trying another method." + ) + if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'): + eigs, _ = torch.linalg.eig(stats) + eigs = eigs.abs() + else: + eigs, _ = torch.eig(stats) + eigs = eigs.norm(dim=1) + stats = eigs.sqrt() # sqrt so it reflects data magnitude, like stddev- not variance if stats_type in ["rms", "stddev"]: @@ -569,11 +579,10 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in (torch.float32, torch.float16, torch.float64): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate( - o, class_name=get_class_name(_module) - ) - + if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, + class_name=get_class_name(_module)) + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] @@ -587,11 +596,9 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in (torch.float32, torch.float16, torch.float64): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( - o, class_name=get_class_name(_module) - ) - + if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, + class_name=get_class_name(_module)) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook)