diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 1cd685d37..e8bedc64e 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -272,7 +272,7 @@ class ModelDiagnostic(object): def attach_diagnostics( - model: nn.Module, opts: TensorDiagnosticOptions + model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None ) -> ModelDiagnostic: """Attach a ModelDiagnostic object to the model by 1) registering forward hook and backward hook on each module, to accumulate @@ -335,7 +335,7 @@ def attach_diagnostics( def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2 ** 20, 512) + opts = TensorDiagnosticOptions(512) diagnostic = TensorDiagnostic(opts, "foo")