diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 309fa110f..0f51b4382 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -392,11 +392,13 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
-        params[k] = saved_params[k]
+        params[k] = saved_params.get(k, 0)
 
-    params["start_epoch"] = saved_params["cur_epoch"]
+    if "cur_epoch" in saved_params:
+        params["start_epoch"] = saved_params["cur_epoch"]
 
     return saved_params
 
@@ -545,7 +547,6 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
-    start_batch: int = 0,
 ) -> None:
     """Train the model for one epoch.
 
@@ -571,8 +572,6 @@
       rank:
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
-      start_batch:
-        If not zero, it starts from this batch for training.
     """
     model.train()
 
@@ -617,7 +616,13 @@
     else:
         optimizer.step()
 
-    for batch_idx, batch in enumerate(train_dl, start=start_batch):
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -646,6 +651,7 @@
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -655,6 +661,7 @@
                 sampler=train_dl.sampler,
                 rank=rank,
             )
+            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -793,10 +800,8 @@ def run(rank, world_size, args):
 
     if checkpoints and "sampler" in checkpoints:
         sampler_state_dict = checkpoints["sampler"]
-        start_batch = sampler_state_dict["diagnostics"]["num_kept_batches"]
     else:
         sampler_state_dict = None
-        start_batch = 0
     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
@@ -840,9 +845,7 @@
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
-            start_batch=start_batch,
        )
-        start_batch = 0
 
         save_checkpoint(
            params=params,
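
The behavioral core of this diff is the mid-epoch resume logic in `train_one_epoch`: `enumerate(train_dl, start=start_batch)` only offsets the reported index, it does not skip already-trained batches, whereas the new `cur_batch_idx` guard actually fast-forwards the dataloader. A minimal sketch of that difference (plain Python, not icefall code; `fake_batches` and `start_batch` are made-up stand-ins for the real dataloader and checkpointed position):

```python
# Contrast the old and new resume behavior on a toy "dataloader".
fake_batches = ["b0", "b1", "b2", "b3", "b4"]
start_batch = 2  # pretend the checkpoint was written after batch 1

# Old approach: enumerate(start=...) shifts the index but still yields every
# batch, so the first two batches would be trained a second time on resume.
old = list(enumerate(fake_batches, start=start_batch))
assert old[0] == (2, "b0")  # index says 2, but the batch is still b0

# New approach: persist the in-epoch position (cur_batch_idx) and skip past
# batches that were already consumed before the checkpoint.
cur_batch_idx = start_batch
new = []
for batch_idx, batch in enumerate(fake_batches):
    if batch_idx < cur_batch_idx:
        continue  # already trained before the checkpoint was saved
    cur_batch_idx = batch_idx
    new.append((batch_idx, batch))
assert new[0] == (2, "b2")  # training really resumes at batch 2
```

This is also why the diff stores `cur_batch_idx` in `params` only for the duration of `save_checkpoint_with_global_batch_idx` and then deletes it: the value is meaningful only for checkpoints taken mid-epoch, and `load_checkpoint_if_available` falls back to 0 via `saved_params.get(k, 0)` when it is absent.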