mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Fix diagnostics-getting code
This commit is contained in:
parent
11bea4513e
commit
13db33ffa2
@ -115,7 +115,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless/exp",
|
default="pruned_transducer_stateless2/exp",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
@ -556,6 +556,9 @@ def train_one_epoch(
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
|
return
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
@ -665,7 +668,11 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
diagnostic = diagnostics.attach_diagnostics(model)
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
|
2 ** 22
|
||||||
|
) # allow 4 megabytes per sub-module
|
||||||
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user