From 54f087fead31e416da90fa657f735eb382b54878 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 24 Feb 2023 16:12:30 +0800 Subject: [PATCH] Fix to diagnostics --- icefall/diagnostics.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 15ec28b3b..51e816105 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -522,6 +522,8 @@ def attach_diagnostics( if name == "": name = "" + + # Setting model_diagnostic=ans and n=name below, instead of trying to # capture the variables, ensures that we use the current values. # (this matters for `name`, since the variable gets overwritten). @@ -533,26 +535,28 @@ def attach_diagnostics( if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor): + if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): _model_diagnostic[f"{_name}.output"].accumulate(_output, class_name=type(_module).__name__) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=type(_module).__name__) + if o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, + class_name=type(_module).__name__) def backward_hook( _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor): + if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): _model_diagnostic[f"{_name}.grad"].accumulate(_output, class_name=type(_module).__name__) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=type(_module).__name__) + if o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, + class_name=type(_module).__name__) module.register_forward_hook(forward_hook)