mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix diagnostics-getting code
This commit is contained in:
parent
11bea4513e
commit
13db33ffa2
@ -115,7 +115,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless/exp",
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -556,6 +556,9 @@ def train_one_epoch(
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
@ -665,7 +668,11 @@ def run(rank, world_size, args):
|
||||
|
||||
|
||||
if params.print_diagnostics:
|
||||
diagnostic = diagnostics.attach_diagnostics(model)
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user