Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Find CUDA OOM batches before starting training
commit 1c7c79f2fc (parent fee1f84b20)
@@ -606,6 +606,23 @@ def run(rank, world_size, args):
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
 
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info('Sanity check -- see if any of the batches in epoch 0 would cause OOM.')
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        logging.info(f'* criterion: {criterion} (={crit_values[criterion]}) ...')
+        batch = train_dl.dataset[cuts]
+        try:
+            compute_loss(params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True)
+            logging.info('OK!')
+        except RuntimeError as e:
+            if 'CUDA out of memory' in str(e):
+                logging.error(
+                    'Your GPU ran out of memory with the current max_duration setting. '
+                    'We recommend decreasing max_duration and trying again.'
+                )
+            raise
+
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
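For reference, the same probe can be factored into a reusable helper. The sketch below is not part of the commit: the helper name scan_pessimistic_batches_for_oom and the compute_loss_fn callback are assumptions introduced here to keep the example self-contained. The only lhotse API it relies on, find_pessimistic_batches, is the one the diff itself imports; the recipe-specific compute_loss(...) call is wrapped behind the callback.

import logging
from typing import Callable

from lhotse.dataset import find_pessimistic_batches


def scan_pessimistic_batches_for_oom(
    train_dl,                   # DataLoader built on a lhotse sampler/dataset
    compute_loss_fn: Callable,  # hypothetical callback wrapping the recipe's compute_loss
) -> None:
    """Run the worst-case batches of epoch 0 once, so that a CUDA OOM caused by
    an overly large max_duration is caught before training starts."""
    logging.info(
        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
    )
    # For each "pessimistic" criterion, lhotse returns the cuts that realize it
    # together with the criterion's value.
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        logging.info(f"* criterion: {criterion} (={crit_values[criterion]}) ...")
        batch = train_dl.dataset[cuts]
        try:
            compute_loss_fn(batch)
            logging.info("OK!")
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current max_duration "
                    "setting. We recommend decreasing max_duration and trying again."
                )
            raise


# Usage sketch (params/model/graph_compiler come from the recipe, as in the diff):
#   scan_pessimistic_batches_for_oom(
#       train_dl,
#       lambda batch: compute_loss(
#           params=params, model=model, batch=batch,
#           graph_compiler=graph_compiler, is_training=True,
#       ),
#   )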