From ed2d2932efdab3f5720e4aaba638f9d67b24c953 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 23 Mar 2022 12:02:06 +0800
Subject: [PATCH] Minor fixes for saving checkpoints.

---
 .../ASR/pruned_transducer_stateless/train.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index e71f0d1c6..309fa110f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -392,7 +392,6 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
-        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
@@ -546,6 +545,7 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
+    start_batch: int = 0,
 ) -> None:
     """Train the model for one epoch.
 
@@ -571,6 +571,8 @@ def train_one_epoch(
       rank:
         The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
+      start_batch:
+        If not zero, it starts from this batch for training.
     """
     model.train()
 
@@ -615,13 +617,7 @@ def train_one_epoch(
     else:
         optimizer.step()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
-    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
+    for batch_idx, batch in enumerate(train_dl, start=start_batch):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -650,7 +646,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -660,7 +655,6 @@ def train_one_epoch(
                 sampler=train_dl.sampler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -799,8 +793,10 @@ def run(rank, world_size, args):
     if checkpoints and "sampler" in checkpoints:
         sampler_state_dict = checkpoints["sampler"]
+        start_batch = sampler_state_dict["diagnostics"]["num_kept_batches"]
     else:
         sampler_state_dict = None
+        start_batch = 0
 
     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
@@ -844,7 +840,9 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            start_batch=start_batch,
         )
+        start_batch = 0
 
         save_checkpoint(
             params=params,
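
Note (not part of the patch): with this change, resuming no longer relies on storing cur_batch_idx in params and skipping already-seen batches inside the training loop. The resume position is instead derived in run() from the sampler state dict saved in the checkpoint (sampler_state_dict["diagnostics"]["num_kept_batches"]), and train_one_epoch only offsets the reported index via enumerate(train_dl, start=start_batch), on the assumption that the restored sampler itself continues from the first unconsumed batch. The start_batch = 0 assignment added after the train_one_epoch call resets the offset so that subsequent epochs count from 0 again. The sketch below is a minimal stand-alone illustration of that flow; make_train_dl and the ten toy batches are hypothetical stand-ins for librispeech.train_dataloaders() and the real lhotse dataloader, while the "diagnostics"/"num_kept_batches" keys simply mirror the ones read above.

# Illustrative sketch only, not part of train.py.
def make_train_dl(sampler_state_dict=None):
    """Hypothetical stand-in for librispeech.train_dataloaders(): a restored
    sampler is assumed to yield only the batches not yet consumed."""
    all_batches = [f"batch-{i}" for i in range(10)]
    if sampler_state_dict is None:
        return all_batches
    kept = sampler_state_dict["diagnostics"]["num_kept_batches"]
    return all_batches[kept:]


def train_one_epoch(train_dl, start_batch=0):
    # start_batch only offsets the reported batch_idx; nothing is skipped
    # here, because the dataloader already starts at the right position.
    for batch_idx, batch in enumerate(train_dl, start=start_batch):
        print(batch_idx, batch)


# Fresh run: no checkpoint, indices start at 0.
train_one_epoch(make_train_dl(), start_batch=0)

# Resumed run: the checkpoint records that 4 batches were already consumed,
# so the first reported index is 4 and the first batch is batch-4.
checkpoints = {"sampler": {"diagnostics": {"num_kept_batches": 4}}}
sampler_state_dict = checkpoints["sampler"]
start_batch = sampler_state_dict["diagnostics"]["num_kept_batches"]
train_one_epoch(make_train_dl(sampler_state_dict), start_batch=start_batch)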