Fix bug in diagnostics.py

This commit is contained in:
Daniel Povey 2022-12-01 16:23:50 +08:00
parent 2969eb5467
commit 2102038e0e

View File

@ -543,6 +543,10 @@ def attach_diagnostics(
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=type(_module).__name__)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)
if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]:
# For these specific module types, accumulate some additional diagnostics
# that can help us improve the activation function. These require a lot of memory,