Find CUDA OOM batches before starting training

Piotr Żelasko 2021-10-14 21:28:11 -04:00
parent fee1f84b20
commit 1c7c79f2fc


@@ -606,6 +606,23 @@ def run(rank, world_size, args):
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
from lhotse.dataset import find_pessimistic_batches

logging.info('Sanity check -- see if any of the batches in epoch 0 would cause OOM.')
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
    logging.info(f'* criterion: {criterion} (={crit_values[criterion]}) ...')
    batch = train_dl.dataset[cuts]
    try:
        compute_loss(
            params=params,
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=True,
        )
        logging.info('OK!')
    except RuntimeError as e:
        if 'CUDA out of memory' in str(e):
            logging.error(
                'Your GPU ran out of memory with the current max_duration setting. '
                'We recommend decreasing max_duration and trying again.'
            )
        raise
for epoch in range(params.start_epoch, params.num_epochs):
    train_dl.sampler.set_epoch(epoch)
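
For context, `find_pessimistic_batches` comes from Lhotse: it scans the sampler's batches and keeps, for each of several built-in size criteria (e.g. largest total batch duration), the batch that maximizes that criterion, so running a single training step on those batches approximates the worst-case GPU memory load of an epoch. Below is a minimal standalone sketch of the same check outside this training script; the manifest path and `max_duration` value are illustrative assumptions, not taken from this commit:

```python
import logging

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, find_pessimistic_batches

logging.basicConfig(level=logging.INFO)

# Hypothetical manifest path -- substitute your own training cuts.
cuts = CutSet.from_file('data/fbank/cuts_train.jsonl.gz')
sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)

# One pass over the sampler's batches; returns, per criterion, the most
# "pessimistic" batch (as a CutSet) and the value of that criterion.
batches, crit_values = find_pessimistic_batches(sampler)
for criterion, cut_set in batches.items():
    logging.info(f'{criterion}: {crit_values[criterion]} ({len(cut_set)} cuts)')
```

Each returned `CutSet` can be passed through the dataset's `__getitem__` to build an actual batch, which is exactly what the diff above does before attempting a forward pass.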