def get_class_name(module: nn.Module) -> str:
    """Return the class name of `module`, annotated with its limits if known.

    The class is matched by name only:

      - ``Balancer`` / ``ActivationBalancer``: the name is followed by
        ``[min_positive,max_positive,min_abs,max_abs]``, each value being
        converted with ``float()``.
      - ``AbsValuePenalizer``: the name is followed by ``[limit]``.
      - anything else: the bare class name.

    The suffix is best-effort.  Someone may be using a different version of
    these modules with different member names; in that case (or if a value
    cannot be formatted) the bare class name is returned instead of raising.

    Args:
      module:
        The module to describe.  It is passed as the ``class_name`` argument
        of the diagnostics ``accumulate`` calls made by the hooks.

    Returns:
      The class name, possibly followed by a bracketed list of limits.
    """
    ans = type(module).__name__
    try:
        if ans in ("Balancer", "ActivationBalancer"):
            limits = (
                module.min_positive,
                module.max_positive,
                module.min_abs,
                module.max_abs,
            )
            # The right-hand side is fully evaluated before `ans` is updated,
            # so a failure on any member leaves the bare class name intact.
            ans += "[" + ",".join(f"{float(x)}" for x in limits) + "]"
        elif ans == "AbsValuePenalizer":
            ans += f"[{module.limit}]"
    except Exception:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        pass
    return ans
_model_diagnostic[f"{_name}.grad"].accumulate(_output, - class_name=type(_module).__name__) + class_name=get_class_name(_module)) elif isinstance(_output, tuple): for i, o in enumerate(_output): if o.dtype in ( torch.float32, torch.float16, torch.float64 ): _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=type(_module).__name__) + class_name=get_class_name(_module)) module.register_forward_hook(forward_hook) @@ -574,7 +590,7 @@ def attach_diagnostics( _input, = _input assert isinstance(_input, Tensor) _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input, - class_name=type(_module).__name__) + class_name=get_class_name(_module)) def scalar_backward_hook( _module, _input, _output, _model_diagnostic=ans, _name=name