mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Update hooks.py (#1564)
This commit is contained in:
parent
9bd30853ae
commit
d5cd78a637
@ -1,4 +1,6 @@
|
|||||||
# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey)
|
# Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
|
||||||
|
# Daniel Povey,
|
||||||
|
# Zengrui Jin,)
|
||||||
#
|
#
|
||||||
# See ../../LICENSE for clarification regarding multiple authors
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -77,7 +79,13 @@ def register_inf_check_hooks(model: nn.Module) -> None:
|
|||||||
if not torch.isfinite(grad.to(torch.float32).sum()):
|
if not torch.isfinite(grad.to(torch.float32).sum()):
|
||||||
logging.warning(f"The sum of {_name}.param_grad is not finite")
|
logging.warning(f"The sum of {_name}.param_grad is not finite")
|
||||||
|
|
||||||
|
try:
|
||||||
parameter.register_hook(param_backward_hook)
|
parameter.register_hook(param_backward_hook)
|
||||||
|
except:
|
||||||
|
logging.warning(
|
||||||
|
f"Warning: could not register backward hook for parameter {name}, "
|
||||||
|
f"it might not be differentiable."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_inf_check_hooks():
|
def _test_inf_check_hooks():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user