Remove unnecessary option for diagnostics code, collect on more batches

Daniel Povey 2022-05-19 11:35:54 +08:00
parent c0fdfabaf3
commit c736b39c7d
7 changed files with 14 additions and 35 deletions

View File

@@ -689,7 +689,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -831,10 +831,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     gigaspeech = GigaSpeechAsrDataModule(args)

View File

@@ -695,7 +695,7 @@ def train_one_epoch(
             display_and_save_batch(batch, params=params, sp=sp)
             raise

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -839,10 +839,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     librispeech = LibriSpeechAsrDataModule(args)

View File

@@ -767,7 +767,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -938,10 +938,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

View File

@@ -724,7 +724,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -888,10 +888,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     librispeech = LibriSpeechAsrDataModule(args)

View File

@@ -523,7 +523,7 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if batch_idx % params.log_interval == 0:
@@ -635,10 +635,7 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:

View File

@@ -511,7 +511,7 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if batch_idx % params.log_interval == 0:
@@ -623,10 +623,7 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:

View File

@@ -690,7 +690,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return

         if (
@@ -832,10 +832,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

     spgispeech = SPGISpeechAsrDataModule(args)
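
The change is the same in every script: attach_diagnostics is called with its default options instead of an explicit TensorDiagnosticOptions(2 ** 22), and train_one_epoch now returns after 30 batches rather than 5, so the hooks accumulate statistics over more data at a small cost in startup time. A minimal sketch of that flow, assuming icefall's diagnostics module and the print_diagnostics() call these scripts make after train_one_epoch returns; collect_diagnostics, compute_loss and the data loader below are placeholders, not the scripts' real code:

from icefall import diagnostics


def collect_diagnostics(model, train_dl):
    # Attach forward/backward hooks with the default options; the explicit
    # 2 ** 22 per-module memory limit is no longer passed in.
    diagnostic = diagnostics.attach_diagnostics(model)

    for batch_idx, batch in enumerate(train_dl):
        loss = compute_loss(model, batch)  # placeholder for the real training step
        loss.backward()

        # Previously the run stopped after batch 5; statistics are now
        # accumulated over the first 30 batches.
        if batch_idx == 30:
            break

    # Print the accumulated per-module statistics.
    diagnostic.print_diagnostics()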