Minor fixes.

Fangjun Kuang 2021-08-15 11:45:53 +08:00
parent 72c0220830
commit 14e0886559


@@ -194,7 +194,10 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename, model=model, optimizer=optimizer, scheduler=scheduler,
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
     )
     keys = [
@@ -512,6 +515,7 @@ def train_one_epoch(
     params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx == 0:
+            logging.info("save a batch for OOM handling")
             # Use this batch to replace the batch that's causing OOM
             params.saved_batch = batch
@@ -597,7 +601,9 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
-                    "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
                 )
         if batch_idx > 0 and batch_idx % params.reset_interval == 0:
             tot_loss = 0.0  # sum of losses over all batches
@@ -646,6 +652,9 @@ def train_one_epoch(
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss

+    if "saved_batch" in params:
+        del params["saved_batch"]
+

 def run(rank, world_size, args):
     """
@@ -749,10 +758,12 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
         )
-        del params.saved_batch

         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, rank=rank,
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
        )

     logging.info("Done!")