diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 17b4f9c4b..1fd059b79 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -79,6 +79,7 @@ from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) +from icefall.hooks import register_inf_check_hooks from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -308,6 +309,13 @@ def get_parser(): help="Accumulate stats on activations, print them and exit.", ) + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + parser.add_argument( "--save-every-n", type=int, @@ -992,6 +1000,9 @@ def run(rank, world_size, args): ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) + if params.inf_check: + register_inf_check_hooks(model) + librispeech = LibriSpeechAsrDataModule(args) train_cuts = librispeech.train_clean_100_cuts()