Fixes after review.

This commit is contained in:
Fangjun Kuang 2022-05-05 12:47:24 +08:00
parent ce885c6a67
commit a0dbfba77d

View File

@@ -691,9 +691,8 @@ def train_one_epoch(
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
display_and_save_batch(batch, params=params, sp=sp)
except:
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5: