diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index f95ce86b1..8ebd13734 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -873,13 +873,14 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + if params.start_batch <= 0: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index aa78bc412..f8e41271e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -924,7 +924,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 3b9fb710c..144a429cc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -1040,14 +1040,15 @@ def run(rank, world_size, args): # It's time consuming to include `giga_train_dl` here # for dl in [train_dl, giga_train_dl]: for dl in [train_dl]: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=dl, - optimizer=optimizer, - sp=sp, - params=params, - warmup=0.0 if params.start_epoch == 0 else 1.0, - ) + if params.start_batch <= 0: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 0 else 1.0, + ) scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 8c58a97e3..d2c7fb3d9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -973,7 +973,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index a8cbff44e..dca162d26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -962,7 +962,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index c054527ca..d6fbea357 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -507,9 +507,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -754,13 +751,7 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -802,7 +793,6 @@ def train_one_epoch( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -815,7 +805,6 @@ def train_one_epoch( scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -990,7 +979,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl,