mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix to diagnostics
This commit is contained in:
parent
0191e8f3e4
commit
54f087fead
@ -522,6 +522,8 @@ def attach_diagnostics(
|
|||||||
if name == "":
|
if name == "":
|
||||||
name = "<top-level>"
|
name = "<top-level>"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Setting model_diagnostic=ans and n=name below, instead of trying to
|
# Setting model_diagnostic=ans and n=name below, instead of trying to
|
||||||
# capture the variables, ensures that we use the current values.
|
# capture the variables, ensures that we use the current values.
|
||||||
# (this matters for `name`, since the variable gets overwritten).
|
# (this matters for `name`, since the variable gets overwritten).
|
||||||
@ -533,11 +535,12 @@ def attach_diagnostics(
|
|||||||
if isinstance(_output, tuple) and len(_output) == 1:
|
if isinstance(_output, tuple) and len(_output) == 1:
|
||||||
_output = _output[0]
|
_output = _output[0]
|
||||||
|
|
||||||
if isinstance(_output, Tensor):
|
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||||
_model_diagnostic[f"{_name}.output"].accumulate(_output,
|
_model_diagnostic[f"{_name}.output"].accumulate(_output,
|
||||||
class_name=type(_module).__name__)
|
class_name=type(_module).__name__)
|
||||||
elif isinstance(_output, tuple):
|
elif isinstance(_output, tuple):
|
||||||
for i, o in enumerate(_output):
|
for i, o in enumerate(_output):
|
||||||
|
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
|
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
|
||||||
class_name=type(_module).__name__)
|
class_name=type(_module).__name__)
|
||||||
|
|
||||||
@ -546,11 +549,12 @@ def attach_diagnostics(
|
|||||||
):
|
):
|
||||||
if isinstance(_output, tuple) and len(_output) == 1:
|
if isinstance(_output, tuple) and len(_output) == 1:
|
||||||
_output = _output[0]
|
_output = _output[0]
|
||||||
if isinstance(_output, Tensor):
|
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
|
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
|
||||||
class_name=type(_module).__name__)
|
class_name=type(_module).__name__)
|
||||||
elif isinstance(_output, tuple):
|
elif isinstance(_output, tuple):
|
||||||
for i, o in enumerate(_output):
|
for i, o in enumerate(_output):
|
||||||
|
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||||
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
|
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
|
||||||
class_name=type(_module).__name__)
|
class_name=type(_module).__name__)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user