Update hooks.py (#1564)

This commit is contained in:
zr_jin 2024-03-20 16:43:45 +08:00 committed by GitHub
parent 9bd30853ae
commit d5cd78a637
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,6 @@
# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey) # Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
# Daniel Povey,
# Zengrui Jin,)
# #
# See ../../LICENSE for clarification regarding multiple authors # See ../../LICENSE for clarification regarding multiple authors
# #
@ -77,7 +79,13 @@ def register_inf_check_hooks(model: nn.Module) -> None:
if not torch.isfinite(grad.to(torch.float32).sum()): if not torch.isfinite(grad.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.param_grad is not finite") logging.warning(f"The sum of {_name}.param_grad is not finite")
try:
parameter.register_hook(param_backward_hook) parameter.register_hook(param_backward_hook)
except:
logging.warning(
f"Warning: could not register backward hook for parameter {name}, "
f"it might not be differentiable."
)
def _test_inf_check_hooks(): def _test_inf_check_hooks():