mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Introduce backprop in finding OOM batches
This commit is contained in:
parent
060117a9ff
commit
403d1744ff
@ -618,13 +618,17 @@ def run(rank, world_size, args):
|
||||
)
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
compute_loss(
|
||||
optimizer.zero_grad()
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
)
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
logging.info("OK!")
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
|
Loading…
x
Reference in New Issue
Block a user