From 21a1be55b5ce75013b7dabce000a1f829d442fb4 Mon Sep 17 00:00:00 2001
From: Nickolay Shmyrev
Date: Mon, 3 Jul 2023 18:32:48 +0200
Subject: [PATCH] Skip batches on load

---
 icefall/rnn_lm/train.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 80f8c238f..3d206d139 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -302,6 +302,9 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]
 
+    if "cur_batch_idx" in saved_params:
+        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
     return saved_params
 
 
@@ -457,7 +460,14 @@ def train_one_epoch(
 
     tot_loss = MetricsTracker()
 
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
@@ -482,6 +492,7 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -490,6 +501,7 @@ def train_one_epoch(
                 optimizer=optimizer,
                 rank=rank,
             )
+            del params.cur_batch_idx
 
         if batch_idx % params.log_interval == 0:
            # Note: "frames" here means "num_tokens"
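
The patch records the in-epoch batch index (cur_batch_idx) in checkpoints
written by save_checkpoint_with_global_batch_idx, restores it in
load_checkpoint_if_available, and on the next run fast-forwards the
dataloader by skipping batches below that index, so training resumes
mid-epoch instead of replaying the whole epoch. The del params.cur_batch_idx
immediately after saving keeps the index out of live training state, so it
only ever exists inside a checkpoint. Below is a minimal, self-contained
Python sketch of the same pattern; every name in it is illustrative, not
an icefall API.

def train_one_epoch(batches, params):
    """Process one epoch, skipping batches already seen before a restart.

    `batches` stands in for the real dataloader and `params` for the
    restored checkpoint state; both are illustrative stand-ins.
    """
    # On a resumed run the checkpoint carried "cur_batch_idx"; on a
    # fresh run the key is absent and we start from batch 0.
    cur_batch_idx = params.get("cur_batch_idx", 0)
    processed = []

    for batch_idx, batch in enumerate(batches):
        # Fast-forward past batches handled before the interruption.
        # This still iterates the skipped batches (paying their loading
        # cost); it only avoids redoing the training steps.
        if batch_idx < cur_batch_idx:
            continue
        cur_batch_idx = batch_idx

        processed.append(batch)  # stands in for the forward/backward step

        # Record progress as a checkpoint would; the real patch sets
        # params.cur_batch_idx just before saving and deletes it after.
        params["cur_batch_idx"] = batch_idx

    return processed


# Fresh run processes all five batches; a resumed run skips 0 and 1.
assert train_one_epoch(range(5), {}) == [0, 1, 2, 3, 4]
assert train_one_epoch(range(5), {"cur_batch_idx": 2}) == [2, 3, 4]

One consequence of this design: because the skip simply re-enumerates the
dataloader, it assumes batches come back in the same order as in the
interrupted run (e.g. a fixed shuffling seed); with a nondeterministic
sampler the resumed run would train on a different remainder of the epoch.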