diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index a3c480c9c..37872f233 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -1,6 +1,7 @@
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
+# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
 #                                             Zengwei Yao
-#                                             Mingshuang Luo)
+#                                             Mingshuang Luo,
+#                                             Zengrui Jin)
 #
 # See ../LICENSE for clarification regarding multiple authors
 #
@@ -16,9 +17,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -653,7 +655,15 @@ def attach_diagnostics(
             _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
             _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
 
-        parameter.register_hook(param_backward_hook)
+        try:
+            parameter.register_hook(param_backward_hook)
+        except Exception:
+            # Best-effort: skip parameters that reject a backward hook, but do
+            # not swallow KeyboardInterrupt/SystemExit as a bare except would.
+            logging.warning(
+                f"Warning: could not register backward hook for parameter {name}, "
+                f"it might not be differentiable."
+            )
 
     return ans