From c736b39c7d20c37faabe7b511566a0d9b584de03 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 May 2022 11:35:54 +0800 Subject: [PATCH] Remove unnecessary option for diagnostics code, collect on more batches --- egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py | 7 ++----- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 7 ++----- egs/librispeech/ASR/pruned_transducer_stateless3/train.py | 7 ++----- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 7 ++----- egs/librispeech/ASR/transducer_stateless/train.py | 7 ++----- egs/librispeech/ASR/transducer_stateless2/train.py | 7 ++----- egs/spgispeech/ASR/pruned_transducer_stateless2/train.py | 7 ++----- 7 files changed, 14 insertions(+), 35 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 4421ce2aa..83ae25561 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -689,7 +689,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -831,10 +831,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) gigaspeech = GigaSpeechAsrDataModule(args) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 51c1a231a..eed2df755 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -695,7 +695,7 @@ def train_one_epoch( display_and_save_batch(batch, params=params, sp=sp) raise - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -839,10 +839,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) librispeech = LibriSpeechAsrDataModule(args) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 037f99bc7..f5a25a226 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -767,7 +767,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -938,10 +938,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) librispeech = LibriSpeech(manifest_dir=args.manifest_dir) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 4ff69d521..ca7207122 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -724,7 +724,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -888,10 +888,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) librispeech = LibriSpeechAsrDataModule(args) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 89f754b20..cb7f08a09 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -523,7 +523,7 @@ def train_one_epoch( loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if batch_idx % params.log_interval == 0: @@ -635,10 +635,7 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py index 8ceffb489..cb13e317c 100755 --- a/egs/librispeech/ASR/transducer_stateless2/train.py +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -511,7 +511,7 @@ def train_one_epoch( loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if batch_idx % params.log_interval == 0: @@ -623,10 +623,7 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index 6c66bfb62..dda29b3e5 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -690,7 +690,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() - if params.print_diagnostics and batch_idx == 5: + if params.print_diagnostics and batch_idx == 30: return if ( @@ -832,10 +832,7 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) + diagnostic = diagnostics.attach_diagnostics(model) spgispeech = SPGISpeechAsrDataModule(args)