Fix to diagnostics

Daniel Povey 2023-02-24 16:12:30 +08:00
parent 0191e8f3e4
commit 54f087fead


@@ -522,6 +522,8 @@ def attach_diagnostics(
         if name == "":
             name = "<top-level>"
 
         # Setting model_diagnostic=ans and n=name below, instead of trying to
         # capture the variables, ensures that we use the current values.
         # (this matters for `name`, since the variable gets overwritten).
@@ -533,26 +535,28 @@ def attach_diagnostics(
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
-            if isinstance(_output, Tensor):
+            if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
                 _model_diagnostic[f"{_name}.output"].accumulate(_output,
                                                                 class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                             class_name=type(_module).__name__)
 
         def backward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
         ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
-            if isinstance(_output, Tensor):
+            if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
                 _model_diagnostic[f"{_name}.grad"].accumulate(_output,
                                                               class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                           class_name=type(_module).__name__)
 
         module.register_forward_hook(forward_hook)
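
The fix guards both hooks with a floating-point dtype check, so statistics are only accumulated for torch.float32, torch.float16, and torch.float64 tensors; integer or bool outputs (e.g. mask or length tensors returned alongside activations in a tuple) are now skipped instead of being fed to accumulate(). Below is a minimal, self-contained sketch of that pattern in plain PyTorch, not the repository's actual diagnostics accumulator: attach_simple_stats, ToyBlock, and the dict of running sums are hypothetical stand-ins used only for illustration.

    import torch
    from torch import Tensor, nn

    _FLOAT_DTYPES = (torch.float32, torch.float16, torch.float64)

    def attach_simple_stats(model: nn.Module) -> dict:
        """Register forward hooks that accumulate simple per-output statistics,
        skipping non-floating-point tensors just as the commit does."""
        stats = {}

        def forward_hook(_module, _input, _output, _name):
            outputs = _output if isinstance(_output, tuple) else (_output,)
            for i, o in enumerate(outputs):
                # The commit's guard: only floating-point tensors are accumulated.
                if isinstance(o, Tensor) and o.dtype in _FLOAT_DTYPES:
                    key = f"{_name}.output[{i}]"
                    s = stats.setdefault(key, {"sum": 0.0, "sumsq": 0.0, "count": 0})
                    s["sum"] += o.detach().double().sum().item()
                    s["sumsq"] += (o.detach().double() ** 2).sum().item()
                    s["count"] += o.numel()

        for name, module in model.named_modules():
            if name == "":
                name = "<top-level>"
            # Bind the current value of `name` through a default argument, as the
            # original code does, so later loop iterations don't overwrite it.
            module.register_forward_hook(
                lambda m, i, o, _name=name: forward_hook(m, i, o, _name)
            )
        return stats

    class ToyBlock(nn.Module):
        def forward(self, x):
            mask = x > 0            # bool tensor: skipped by the dtype guard
            return x.relu(), mask   # tuple output, like the hooked modules above

    model = ToyBlock()
    stats = attach_simple_stats(model)
    model(torch.randn(4, 8))
    print(stats)  # only "<top-level>.output[0]" is present; the bool mask was skipped

Without the guard, the tuple branch calls the accumulator on every element, so a module returning a bool mask next to its activations would break the floating-point statistics; with it, such outputs are silently ignored.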