mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Introduce backprop in finding OOM batches
This commit is contained in:
parent
060117a9ff
commit
403d1744ff
@ -618,13 +618,17 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
batch = train_dl.dataset[cuts]
|
batch = train_dl.dataset[cuts]
|
||||||
try:
|
try:
|
||||||
compute_loss(
|
optimizer.zero_grad()
|
||||||
|
loss, _ = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
)
|
)
|
||||||
|
loss.backward()
|
||||||
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
|
optimizer.step()
|
||||||
logging.info("OK!")
|
logging.info("OK!")
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "CUDA out of memory" in str(e):
|
if "CUDA out of memory" in str(e):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user