mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Skip batches on load
This commit is contained in:
parent
a4f6069029
commit
21a1be55b5
@ -302,6 +302,9 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -457,7 +460,14 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
x, y, sentence_lengths = batch
|
||||
batch_size = x.size(0)
|
||||
@ -482,6 +492,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -490,6 +501,7 @@ def train_one_epoch(
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
# Note: "frames" here means "num_tokens"
|
||||
|
Loading…
x
Reference in New Issue
Block a user