Merge pull request #376 from danpovey/diagnostics_fix

Diagnostics fix
This commit is contained in:
Daniel Povey 2022-05-19 12:51:07 +08:00 committed by GitHub
commit 2900ed8f8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -272,7 +272,7 @@ class ModelDiagnostic(object):
def attach_diagnostics(
model: nn.Module, opts: TensorDiagnosticOptions
model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
"""Attach a ModelDiagnostic object to the model by
1) registering forward hook and backward hook on each module, to accumulate
@ -335,7 +335,7 @@ def attach_diagnostics(
def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2 ** 20, 512)
opts = TensorDiagnosticOptions(512)
diagnostic = TensorDiagnostic(opts, "foo")