diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 2a937c6b8..e596c0028 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -543,6 +543,10 @@ def attach_diagnostics( _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, class_name=type(_module).__name__) + + module.register_forward_hook(forward_hook) + module.register_backward_hook(backward_hook) + if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]: # For these specific module types, accumulate some additional diagnostics # that can help us improve the activation function. These require a lot of memory,