Fix to diagnostics

Daniel Povey 2023-02-24 16:12:30 +08:00
parent 0191e8f3e4
commit 54f087fead

@@ -522,6 +522,8 @@ def attach_diagnostics(
         if name == "":
             name = "<top-level>"
         # Setting model_diagnostic=ans and n=name below, instead of trying to
         # capture the variables, ensures that we use the current values.
         # (this matters for `name`, since the variable gets overwritten).
@@ -533,11 +535,12 @@ def attach_diagnostics(
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
-            if isinstance(_output, Tensor):
+            if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
                 _model_diagnostic[f"{_name}.output"].accumulate(_output,
                                                                 class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                             class_name=type(_module).__name__)
@@ -546,11 +549,12 @@ def attach_diagnostics(
         ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
-            if isinstance(_output, Tensor):
+            if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
                 _model_diagnostic[f"{_name}.grad"].accumulate(_output,
                                                               class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                           class_name=type(_module).__name__)
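
For context, the change makes both hooks skip outputs and gradients whose dtype is not floating point (for example integer index tensors or boolean masks that some modules return), since the accumulated statistics are only meaningful for float tensors. Below is a minimal, self-contained sketch of the same hook-plus-dtype-guard pattern; the names (attach_simple_diagnostics, StatsAccumulator, FLOAT_DTYPES) are illustrative stand-ins, not the actual diagnostics API touched by this commit.

import torch
import torch.nn as nn
from collections import defaultdict

FLOAT_DTYPES = (torch.float32, torch.float16, torch.float64)


class StatsAccumulator:
    """Keeps a running mean of the absolute value of whatever it is fed."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def accumulate(self, t: torch.Tensor):
        self.total += t.detach().abs().mean().item()
        self.count += 1

    def mean_abs(self) -> float:
        return self.total / max(self.count, 1)


def attach_simple_diagnostics(model: nn.Module):
    """Registers forward and backward hooks that accumulate per-module stats."""
    stats = defaultdict(StatsAccumulator)

    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # Default arguments (_name=name) bind the current loop values,
        # mirroring the trick described in the comment in the diff above.
        def forward_hook(_module, _input, _output, _name=name):
            outputs = _output if isinstance(_output, tuple) else (_output,)
            for i, o in enumerate(outputs):
                # The guard added by this commit: ignore non-float outputs.
                if isinstance(o, torch.Tensor) and o.dtype in FLOAT_DTYPES:
                    stats[f"{_name}.output[{i}]"].accumulate(o)

        def backward_hook(_module, _grad_input, _grad_output, _name=name):
            for i, g in enumerate(_grad_output):
                # Same guard on the gradient side; also skips None entries.
                if isinstance(g, torch.Tensor) and g.dtype in FLOAT_DTYPES:
                    stats[f"{_name}.grad[{i}]"].accumulate(g)

        module.register_forward_hook(forward_hook)
        module.register_full_backward_hook(backward_hook)

    return stats


# Usage: run a forward/backward pass, then inspect the accumulated stats.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    stats = attach_simple_diagnostics(model)
    model(torch.randn(3, 4)).sum().backward()
    for key, acc in stats.items():
        print(f"{key}: mean |value| = {acc.mean_abs():.4f}")

The sketch uses register_full_backward_hook for the gradient side, which is the non-deprecated PyTorch API; the real diagnostics code may register its hooks differently, but the dtype filtering shown is the point of the commit.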