Add inf check hooks

This commit is contained in:
Daniel Povey 2022-10-22 17:16:29 +08:00
parent e8066b5825
commit 525e87a82d

View File

@@ -79,6 +79,7 @@ from icefall.checkpoint import (
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -308,6 +309,13 @@ def get_parser():
        help="Accumulate stats on activations, print them and exit.",
    )
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
    parser.add_argument(
        "--save-every-n",
        type=int,
@@ -992,6 +1000,9 @@ def run(rank, world_size, args):
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
    librispeech = LibriSpeechAsrDataModule(args)

    train_cuts = librispeech.train_clean_100_cuts()