mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Update diagnostics.py (#1562)
This commit is contained in:
parent
413220d6a4
commit
9bd30853ae
@ -1,6 +1,7 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
|
||||
# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
|
||||
# Zengwei Yao
|
||||
# Mingshuang Luo)
|
||||
# Mingshuang Luo,
|
||||
# Zengrui Jin,)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -16,9 +17,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@ -653,7 +655,13 @@ def attach_diagnostics(
|
||||
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
|
||||
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
|
||||
|
||||
parameter.register_hook(param_backward_hook)
|
||||
try:
|
||||
parameter.register_hook(param_backward_hook)
|
||||
except:
|
||||
logging.warning(
|
||||
f"Warning: could not register backward hook for parameter {name}, "
|
||||
f"it might not be differentiable."
|
||||
)
|
||||
|
||||
return ans
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user