From 13db33ffa2dba26a528748979fa202b6949fc0e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 15:53:53 +0800 Subject: [PATCH] Fix diagnostics-getting code --- .../ASR/pruned_transducer_stateless2/train.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 51858448d..b7cd45334 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -115,7 +115,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -556,6 +556,9 @@ def train_one_epoch( optimizer.step() optimizer.zero_grad() + if params.print_diagnostics and batch_idx == 5: + return + if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, " @@ -665,7 +668,11 @@ def run(rank, world_size, args): if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + librispeech = LibriSpeechAsrDataModule(args)