Introduce backprop in finding OOM batches

This commit is contained in:
Piotr Żelasko 2021-10-15 10:05:13 -04:00
parent 060117a9ff
commit 403d1744ff

View File

@@ -618,13 +618,17 @@ def run(rank, world_size, args):
 )
 batch = train_dl.dataset[cuts]
 try:
-compute_loss(
+optimizer.zero_grad()
+loss, _ = compute_loss(
 params=params,
 model=model,
 batch=batch,
 graph_compiler=graph_compiler,
 is_training=True,
 )
+loss.backward()
+clip_grad_norm_(model.parameters(), 5.0, 2.0)
+optimizer.step()
 logging.info("OK!")
 except RuntimeError as e:
 if "CUDA out of memory" in str(e):