mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Update diagnostics.py (#1562)
This commit is contained in:
parent
413220d6a4
commit
9bd30853ae
@ -1,6 +1,7 @@
|
|||||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
|
# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
|
||||||
# Zengwei Yao
|
# Zengwei Yao
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo,
|
||||||
|
# Zengrui Jin,)
|
||||||
#
|
#
|
||||||
# See ../LICENSE for clarification regarding multiple authors
|
# See ../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -16,9 +17,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@ -653,7 +655,13 @@ def attach_diagnostics(
|
|||||||
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
|
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
|
||||||
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
|
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
|
||||||
|
|
||||||
|
try:
|
||||||
parameter.register_hook(param_backward_hook)
|
parameter.register_hook(param_backward_hook)
|
||||||
|
except:
|
||||||
|
logging.warning(
|
||||||
|
f"Warning: could not register backward hook for parameter {name}, "
|
||||||
|
f"it might not be differentiable."
|
||||||
|
)
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user