warn instead of raising exceptions in inf-check

This commit is contained in:
Han Zhu 2024-12-30 13:58:25 +08:00
parent ad966fb81d
commit 5b598ea285

View File

@@ -40,8 +40,8 @@ def register_inf_check_hooks(model: nn.Module) -> None:
def forward_hook(_module, _input, _output, _name=name): def forward_hook(_module, _input, _output, _name=name):
if isinstance(_output, Tensor): if isinstance(_output, Tensor):
if not torch.isfinite(_output.to(torch.float32).sum()): if not torch.isfinite(_output.to(torch.float32).sum()):
raise ValueError( logging.warning(
f"The sum of {_name}.output is not finite: {_output}" f"The sum of {_name}.output is not finite"
) )
elif isinstance(_output, tuple): elif isinstance(_output, tuple):
for i, o in enumerate(_output): for i, o in enumerate(_output):
@@ -50,8 +50,8 @@ def register_inf_check_hooks(model: nn.Module) -> None:
if not isinstance(o, Tensor): if not isinstance(o, Tensor):
continue continue
if not torch.isfinite(o.to(torch.float32).sum()): if not torch.isfinite(o.to(torch.float32).sum()):
raise ValueError( logging.warning(
f"The sum of {_name}.output[{i}] is not finite: {_output}" f"The sum of {_name}.output[{i}] is not finite"
) )
# default param _name is a way to capture the current value of the variable "name". # default param _name is a way to capture the current value of the variable "name".