Update diagnostics.py (#1562)

This commit is contained in:
zr_jin 2024-03-20 15:35:14 +08:00 committed by GitHub
parent 413220d6a4
commit 9bd30853ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey # Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
# Zengwei Yao # Zengwei Yao
# Mingshuang Luo) # Mingshuang Luo,
# Zengrui Jin,)
# #
# See ../LICENSE for clarification regarding multiple authors # See ../LICENSE for clarification regarding multiple authors
# #
@ -16,9 +17,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -653,7 +655,13 @@ def attach_diagnostics(
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad) _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
parameter.register_hook(param_backward_hook) try:
parameter.register_hook(param_backward_hook)
except:
logging.warning(
f"Warning: could not register backward hook for parameter {name}, "
f"it might not be differentiable."
)
return ans return ans