Fix diagnostics-getting code

This commit is contained in:
Daniel Povey 2022-03-17 15:53:53 +08:00
parent 11bea4513e
commit 13db33ffa2

View File

@ -115,7 +115,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless/exp", default="pruned_transducer_stateless2/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -556,6 +556,9 @@ def train_one_epoch(
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, " f"Epoch {params.cur_epoch}, "
@ -665,7 +668,11 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
diagnostic = diagnostics.attach_diagnostics(model) opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)