Remove unnecessary option for diagnostics code, collect on more batches

Daniel Povey 2022-05-19 11:35:54 +08:00
parent c0fdfabaf3
commit c736b39c7d
7 changed files with 14 additions and 35 deletions
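In short: every training script touched here now calls diagnostics.attach_diagnostics(model) without constructing a TensorDiagnosticOptions object (previously set to 2 ** 22 bytes, i.e. 4 MB, per sub-module), and diagnostics are collected over 30 batches instead of 5 before train_one_epoch returns. A minimal sketch of the pattern after this commit, using only names that appear in the diffs below:

    # In run(): attach diagnostics with default options; the explicit
    # TensorDiagnosticOptions(2 ** 22) argument is no longer passed.
    if params.print_diagnostics:
        diagnostic = diagnostics.attach_diagnostics(model)

    # In train_one_epoch(): collect statistics over more batches
    # before stopping early (was batch_idx == 5).
    if params.print_diagnostics and batch_idx == 30:
        return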

View File

@@ -689,7 +689,7 @@ def train_one_epoch(
            scaler.update()
            optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if (
@@ -831,10 +831,7 @@ def run(rank, world_size, args):
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    gigaspeech = GigaSpeechAsrDataModule(args)

View File

@@ -695,7 +695,7 @@ def train_one_epoch(
            display_and_save_batch(batch, params=params, sp=sp)
            raise

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if (
@@ -839,10 +839,7 @@ def run(rank, world_size, args):
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    librispeech = LibriSpeechAsrDataModule(args)

View File

@@ -767,7 +767,7 @@ def train_one_epoch(
            scaler.update()
            optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if (
@@ -938,10 +938,7 @@ def run(rank, world_size, args):
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

View File

@@ -724,7 +724,7 @@ def train_one_epoch(
            scaler.update()
            optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if (
@@ -888,10 +888,7 @@ def run(rank, world_size, args):
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    librispeech = LibriSpeechAsrDataModule(args)

View File

@@ -523,7 +523,7 @@ def train_one_epoch(
        loss.backward()
        clip_grad_norm_(model.parameters(), 5.0, 2.0)
        optimizer.step()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if batch_idx % params.log_interval == 0:
@@ -635,10 +635,7 @@ def run(rank, world_size, args):
    librispeech = LibriSpeechAsrDataModule(args)

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:

View File

@@ -511,7 +511,7 @@ def train_one_epoch(
        loss.backward()
        clip_grad_norm_(model.parameters(), 5.0, 2.0)
        optimizer.step()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if batch_idx % params.log_interval == 0:
@@ -623,10 +623,7 @@ def run(rank, world_size, args):
    librispeech = LibriSpeechAsrDataModule(args)

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:

View File

@@ -690,7 +690,7 @@ def train_one_epoch(
            scaler.update()
            optimizer.zero_grad()

-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
            return

        if (
@@ -832,10 +832,7 @@ def run(rank, world_size, args):
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)

    spgispeech = SPGISpeechAsrDataModule(args)
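
For context (not part of this diff), the object returned by attach_diagnostics is what eventually reports the accumulated statistics. A hedged sketch of how these scripts typically consume it, where the exact call site is an assumption about the surrounding code:

    # Hedged sketch: after train_one_epoch returns early at batch_idx == 30,
    # the collected tensor statistics are printed. print_diagnostics() is
    # the reporting method of icefall's diagnostics objects.
    if params.print_diagnostics:
        diagnostic.print_diagnostics()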