Fix diagnostics-getting code

This commit is contained in:
Daniel Povey 2022-03-17 15:53:53 +08:00
parent 11bea4513e
commit 13db33ffa2

View File

@@ -115,7 +115,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
-default="pruned_transducer_stateless/exp",
+default="pruned_transducer_stateless2/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -556,6 +556,9 @@ def train_one_epoch(
optimizer.step()
optimizer.zero_grad()
+if params.print_diagnostics and batch_idx == 5:
+return
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
@@ -665,7 +668,11 @@ def run(rank, world_size, args):
if params.print_diagnostics:
-diagnostic = diagnostics.attach_diagnostics(model)
+opts = diagnostics.TensorDiagnosticOptions(
+2 ** 22
+) # allow 4 megabytes per sub-module
+diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args)