Small fixes

This commit is contained in:
Daniel Povey 2022-05-19 12:49:00 +08:00
parent c736b39c7d
commit 5230e73e41

View File

@ -272,7 +272,7 @@ class ModelDiagnostic(object):
def attach_diagnostics(
model: nn.Module, opts: TensorDiagnosticOptions
model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
"""Attach a ModelDiagnostic object to the model by
1) registering forward hook and backward hook on each module, to accumulate
@ -335,7 +335,7 @@ def attach_diagnostics(
def _test_tensor_diagnostic():
opts = TensorDiagnosticOptions(2 ** 20, 512)
opts = TensorDiagnosticOptions(512)
diagnostic = TensorDiagnostic(opts, "foo")