update diagnostics, print limits in Balancer, merge changes from Dan's branch zlm59 (#1109)

This commit is contained in:
Zengwei Yao 2023-06-01 14:24:19 +08:00 committed by GitHub
parent 03853f1ee5
commit 7a604057f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -498,6 +498,22 @@ class ModelDiagnostic(object):
self.diagnostics[k].print_diagnostics()
def get_class_name(module: nn.Module):
ans = type(module).__name__
# we put the below in try blocks in case anyone is using a different version of these modules that
# might have different member names.
if ans == 'Balancer' or ans == 'ActivationBalancer':
try:
ans += f'[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]'
except:
pass
elif ans == 'AbsValuePenalizer':
try:
ans += f'[{module.limit}]'
except:
pass
return ans
def attach_diagnostics(
model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
@ -537,12 +553,12 @@ def attach_diagnostics(
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output"].accumulate(_output,
class_name=type(_module).__name__)
class_name=get_class_name(_module))
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=type(_module).__name__)
class_name=get_class_name(_module))
def backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
@ -551,12 +567,12 @@ def attach_diagnostics(
_output = _output[0]
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
class_name=type(_module).__name__)
class_name=get_class_name(_module))
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=type(_module).__name__)
class_name=get_class_name(_module))
module.register_forward_hook(forward_hook)
@ -574,7 +590,7 @@ def attach_diagnostics(
_input, = _input
assert isinstance(_input, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
class_name=type(_module).__name__)
class_name=get_class_name(_module))
def scalar_backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name