Update train.py

This commit is contained in:
Mingshuang Luo 2022-04-11 21:19:45 +08:00 committed by GitHub
parent 1b854e5c44
commit ef0b6df8f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -677,7 +677,7 @@ def train_one_epoch(
sp=sp, sp=sp,
batch=batch, batch=batch,
is_training=True, is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step) warmup=(params.batch_idx_train / params.model_warm_step),
) )
# summary stats # summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@ -727,7 +727,7 @@ def train_one_epoch(
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train "train/learning_rate", cur_lr, params.batch_idx_train
) )
loss_info.write_summary( loss_info.write_summary(
@ -959,7 +959,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp, sp=sp,
batch=batch, batch=batch,
is_training=True, is_training=True,
warmup=0.0 warmup=0.0,
) )
loss.backward() loss.backward()
optimizer.step() optimizer.step()