mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix to diagnostics
This commit is contained in:
parent
0191e8f3e4
commit
54f087fead
@ -522,6 +522,8 @@ def attach_diagnostics(
|
||||
if name == "":
|
||||
name = "<top-level>"
|
||||
|
||||
|
||||
|
||||
# Setting model_diagnostic=ans and n=name below, instead of trying to
|
||||
# capture the variables, ensures that we use the current values.
|
||||
# (this matters for `name`, since the variable gets overwritten).
|
||||
@ -533,11 +535,12 @@ def attach_diagnostics(
|
||||
if isinstance(_output, tuple) and len(_output) == 1:
|
||||
_output = _output[0]
|
||||
|
||||
if isinstance(_output, Tensor):
|
||||
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
_model_diagnostic[f"{_name}.output"].accumulate(_output,
|
||||
class_name=type(_module).__name__)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
|
||||
class_name=type(_module).__name__)
|
||||
|
||||
@ -546,11 +549,12 @@ def attach_diagnostics(
|
||||
):
|
||||
if isinstance(_output, tuple) and len(_output) == 1:
|
||||
_output = _output[0]
|
||||
if isinstance(_output, Tensor):
|
||||
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
|
||||
class_name=type(_module).__name__)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
|
||||
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
|
||||
class_name=type(_module).__name__)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user