Update hooks.py

This commit is contained in:
jinzr 2024-03-20 16:40:40 +08:00
parent 9bd30853ae
commit 0e726d1526

View File

@ -1,4 +1,6 @@
# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey)
# Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
# Daniel Povey,
# Zengrui Jin,)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -77,7 +79,13 @@ def register_inf_check_hooks(model: nn.Module) -> None:
if not torch.isfinite(grad.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.param_grad is not finite")
parameter.register_hook(param_backward_hook)
try:
parameter.register_hook(param_backward_hook)
except:
logging.warning(
f"Warning: could not register backward hook for parameter {name}, "
f"it might not be differentiable."
)
def _test_inf_check_hooks():