Skip batches on load

This commit is contained in:
Nickolay Shmyrev 2023-07-03 18:32:48 +02:00
parent a4f6069029
commit 21a1be55b5

View File

@@ -302,6 +302,9 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

+    if "cur_batch_idx" in saved_params:
+        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
     return saved_params
@@ -457,7 +460,14 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
@@ -482,6 +492,7 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -490,6 +501,7 @@ def train_one_epoch(
                 optimizer=optimizer,
                 rank=rank,
             )
+            del params.cur_batch_idx

         if batch_idx % params.log_interval == 0:
             # Note: "frames" here means "num_tokens"