Update train.py

This commit is contained in:
Mingshuang Luo 2022-04-11 21:19:45 +08:00 committed by GitHub
parent 1b854e5c44
commit ef0b6df8f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -677,7 +677,7 @@ def train_one_epoch(
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step)
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@ -727,7 +727,7 @@ def train_one_epoch(
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
@ -959,7 +959,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
warmup=0.0
warmup=0.0,
)
loss.backward()
optimizer.step()