# Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
#                                                  Daniel Povey,
#                                                  Zengrui Jin,)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import random

import torch
from torch import Tensor, nn


def register_inf_check_hooks(model: nn.Module) -> None:
    """Register forward and backward hooks on each module of ``model`` to
    check whether its output tensors and gradients are finite.

    A warning is logged whenever a non-finite value (inf or NaN) is detected,
    together with the name of the offending module or parameter.

    Args:
      model:
        The model to be analyzed.
    """
    for name, module in model.named_modules():
        if name == "":
            # The root module has an empty name; give it a readable one
            # for the log messages.
            name = "<top-level>"

        # The default argument _name captures the current value of the
        # loop variable "name" at definition time.
        def forward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                try:
                    if not torch.isfinite(_output.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.output is not finite")
                except RuntimeError:  # e.g. CUDA out of memory
                    pass
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    try:
                        if not torch.isfinite(o.to(torch.float32).sum()):
                            logging.warning(
                                f"The sum of {_name}.output[{i}] is not finite"
                            )
                    except RuntimeError:  # e.g. CUDA out of memory
                        pass

        # The default argument _name captures the current value of the
        # loop variable "name" at definition time.
        def backward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                try:
                    if not torch.isfinite(_output.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.grad is not finite")
                except RuntimeError:  # e.g. CUDA out of memory
                    pass
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    try:
                        if not torch.isfinite(o.to(torch.float32).sum()):
                            logging.warning(
                                f"The sum of {_name}.grad[{i}] is not finite"
                            )
                    except RuntimeError:  # e.g. CUDA out of memory
                        pass

        module.register_forward_hook(forward_hook)
        # Note: register_backward_hook is deprecated in recent PyTorch
        # versions in favor of register_full_backward_hook; it is kept
        # here for compatibility with older releases.
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(grad, _name=name):
            if not torch.isfinite(grad.to(torch.float32).sum()):
                logging.warning(f"The sum of {_name}.param_grad is not finite")

        try:
            parameter.register_hook(param_backward_hook)
        except RuntimeError:
            logging.warning(
                f"Could not register backward hook for parameter {name}; "
                f"it might not be differentiable."
            )


def _test_inf_check_hooks():
    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))

    register_inf_check_hooks(model)

    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100)
        # Inject inf into roughly half of the batches so that both the
        # finite and non-finite code paths are exercised. (Note: adding
        # float("inf") * (T % 2) would not work, since inf * 0 is NaN.)
        if T % 2 == 1:
            x = x + float("inf")
        y = model(x)
        y.sum().backward()


if __name__ == "__main__":
    _test_inf_check_hooks()
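

# A minimal usage sketch (hypothetical; `MyModel`, `dataloader`, and
# `optimizer` below are placeholders, not part of this module). In practice
# the hooks are typically enabled only behind a debug flag, since they add
# a reduction per module output on every step:
#
#     model = MyModel()
#     register_inf_check_hooks(model)
#     for batch in dataloader:
#         loss = model(batch).sum()
#         loss.backward()  # non-finite outputs/grads are logged as warnings
#         optimizer.step()
#         optimizer.zero_grad()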