Update diagnostics.py (#1562)

This commit is contained in:
zr_jin 2024-03-20 15:35:14 +08:00 committed by GitHub
parent 413220d6a4
commit 9bd30853ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
# Zengwei Yao
# Mingshuang Luo)
# Mingshuang Luo,
# Zengrui Jin,)
#
# See ../LICENSE for clarification regarding multiple authors
#
@ -16,9 +17,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
@ -653,7 +655,13 @@ def attach_diagnostics(
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
parameter.register_hook(param_backward_hook)
try:
parameter.register_hook(param_backward_hook)
except:
logging.warning(
f"Warning: could not register backward hook for parameter {name}, "
f"it might not be differentiable."
)
return ans