mirror of https://github.com/k2-fsa/icefall.git

commit ed2d2932ef (parent 8c7995d493)

Minor fixes for saving checkpoints.
@@ -392,7 +392,6 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
-        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
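The resume position now comes from the sampler state (see the hunks below), so "cur_batch_idx" no longer needs to be round-tripped through the checkpoint and is dropped from the restored keys. A minimal sketch of the restoration pattern with made-up values (not icefall code), written defensively so checkpoints saved before or after this change both load:

    # Sketch only; the key names come from the diff, the values are invented.
    saved_params = {
        "batch_idx_train": 3000,
        "best_train_loss": 0.12,
        "best_valid_loss": 0.15,
    }
    params = {}
    keys = ["batch_idx_train", "best_train_loss", "best_valid_loss"]
    for k in keys:
        if k in saved_params:  # defensive: older checkpoints may carry extra keys
            params[k] = saved_params[k]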
@@ -546,6 +545,7 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
+    start_batch: int = 0,
 ) -> None:
     """Train the model for one epoch.

@@ -571,6 +571,8 @@ def train_one_epoch(
       rank:
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
+      start_batch:
+        If not zero, it starts from this batch for training.
     """
     model.train()

@@ -615,13 +617,7 @@ def train_one_epoch(
     else:
         optimizer.step()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
-    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
+    for batch_idx, batch in enumerate(train_dl, start=start_batch):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

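The removed skip loop loaded and discarded every already-seen batch on resume. With the sampler restored from its checkpointed state the dataloader yields only unseen batches, so enumerate's start parameter is enough to keep batch_idx aligned with the position in the full epoch. A toy illustration (invented data, not icefall code) of the fact that start= offsets the index without skipping elements:

    # enumerate(..., start=N) offsets the index; it does not skip items.
    # Pretend the sampler already consumed batches 0-2, so the dataloader
    # only yields the remaining two batches.
    remaining = ["batch3", "batch4"]
    for batch_idx, batch in enumerate(remaining, start=3):
        print(batch_idx, batch)  # prints: 3 batch3, then 4 batch4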
@@ -650,7 +646,6 @@
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -660,7 +655,6 @@
                 sampler=train_dl.sampler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
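Because the resume position now lives in the sampler state (passed via sampler=train_dl.sampler above), the params.cur_batch_idx bookkeeping around the save call can go. A minimal sketch, with assumed names, of what a periodic save that bundles sampler state can look like; icefall's real helper is save_checkpoint_with_global_batch_idx in icefall/checkpoint.py:

    import torch

    def save_checkpoint_sketch(out_dir, global_batch_idx, model, sampler, rank=0):
        # Only rank 0 writes in DDP training.
        if rank != 0:
            return
        checkpoint = {
            "model": model.state_dict(),
            "sampler": sampler.state_dict(),  # lets a later run resume mid-epoch
        }
        torch.save(checkpoint, f"{out_dir}/checkpoint-{global_batch_idx}.pt")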
@@ -799,8 +793,10 @@ def run(rank, world_size, args):

     if checkpoints and "sampler" in checkpoints:
         sampler_state_dict = checkpoints["sampler"]
+        start_batch = sampler_state_dict["diagnostics"]["num_kept_batches"]
     else:
         sampler_state_dict = None
+        start_batch = 0

     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
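start_batch is recovered from the sampler's diagnostics ("num_kept_batches" records how many batches the sampler has already produced), while the state dict itself fast-forwards the new sampler. A self-contained toy (not lhotse's implementation) showing the save/restore round trip this relies on:

    class ResumableSampler:
        """Toy stand-in for a stateful sampler; lhotse's samplers follow the
        same idea: remember how many batches were already yielded and
        fast-forward past them after load_state_dict()."""

        def __init__(self, num_batches):
            self.num_batches = num_batches
            self.num_kept_batches = 0

        def __iter__(self):
            for i in range(self.num_kept_batches, self.num_batches):
                self.num_kept_batches = i + 1
                yield f"batch{i}"

        def state_dict(self):
            return {"diagnostics": {"num_kept_batches": self.num_kept_batches}}

        def load_state_dict(self, state):
            self.num_kept_batches = state["diagnostics"]["num_kept_batches"]

    sampler = ResumableSampler(num_batches=5)
    it = iter(sampler)
    next(it), next(it), next(it)  # consume batches 0-2
    state = sampler.state_dict()  # "checkpoint" the sampler

    restored = ResumableSampler(num_batches=5)
    restored.load_state_dict(state)
    start_batch = state["diagnostics"]["num_kept_batches"]  # 3
    for batch_idx, batch in enumerate(restored, start=start_batch):
        print(batch_idx, batch)  # 3 batch3, then 4 batch4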
@@ -844,7 +840,9 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            start_batch=start_batch,
         )
+        start_batch = 0

         save_checkpoint(
             params=params,
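The offset only applies to the partially finished epoch being resumed, hence start_batch is reset to 0 as soon as the first train_one_epoch call returns. A simplified, runnable sketch of the surrounding loop (names from the diff, bodies stubbed out):

    def train_one_epoch(start_batch=0):
        # Stand-in for the real function; only the resume offset matters here.
        print(f"epoch starts at batch {start_batch}")

    start_epoch, num_epochs, start_batch = 0, 3, 7
    for epoch in range(start_epoch, num_epochs):
        train_one_epoch(start_batch=start_batch)
        start_batch = 0  # later epochs start from their first batch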