Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00
Remove unnecessary option for diagnostics code, collect on more batches
commit c736b39c7d
parent c0fdfabaf3
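
The change is identical across all of the affected recipes: attach_diagnostics() is called without an explicit TensorDiagnosticOptions (the removed argument capped accumulated statistics at 2 ** 22 bytes, i.e. 4 MiB, per sub-module; the helper is presumably left to pick a default itself), and diagnostics collection now runs for 30 batches instead of 5 before the training loop returns. A minimal sketch of the resulting usage pattern, with a hypothetical loop body; only the diagnostics calls mirror the diff:

from icefall import diagnostics

def train_with_diagnostics(model, train_dl, params):
    if params.print_diagnostics:
        # No options argument: attach_diagnostics() is assumed to build
        # default TensorDiagnosticOptions internally after this commit.
        diagnostic = diagnostics.attach_diagnostics(model)

    for batch_idx, batch in enumerate(train_dl):
        ...  # forward / backward / optimizer step, as in train_one_epoch

        # Stop after 30 batches (previously 5) so the collected
        # statistics average over more data.
        if params.print_diagnostics and batch_idx == 30:
            break

    if params.print_diagnostics:
        # Print the accumulated per-module statistics, as the icefall
        # train scripts do once the diagnostics run finishes.
        diagnostic.print_diagnostics()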
@@ -689,7 +689,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()
 
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if (
@@ -831,10 +831,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     gigaspeech = GigaSpeechAsrDataModule(args)
 
@@ -695,7 +695,7 @@ def train_one_epoch(
             display_and_save_batch(batch, params=params, sp=sp)
             raise
 
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if (
@@ -839,10 +839,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     librispeech = LibriSpeechAsrDataModule(args)
 
@@ -767,7 +767,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()
 
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if (
@@ -938,10 +938,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
@@ -724,7 +724,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()
 
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if (
@@ -888,10 +888,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     librispeech = LibriSpeechAsrDataModule(args)
 
@@ -523,7 +523,7 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if batch_idx % params.log_interval == 0:
@@ -635,10 +635,7 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
@@ -511,7 +511,7 @@ def train_one_epoch(
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if batch_idx % params.log_interval == 0:
@@ -623,10 +623,7 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
@@ -690,7 +690,7 @@ def train_one_epoch(
         scaler.update()
         optimizer.zero_grad()
 
-        if params.print_diagnostics and batch_idx == 5:
+        if params.print_diagnostics and batch_idx == 30:
             return
 
         if (
@@ -832,10 +832,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     if params.print_diagnostics:
-        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
-        )  # allow 4 megabytes per sub-module
-        diagnostic = diagnostics.attach_diagnostics(model, opts)
+        diagnostic = diagnostics.attach_diagnostics(model)
 
     spgispeech = SPGISpeechAsrDataModule(args)
 