diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index d1cdfa8bb..267316691 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -606,6 +606,23 @@ def run(rank, world_size, args):
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
 
+    from lhotse.dataset import find_pessimistic_batches
+    logging.info('Sanity check -- see if any of the batches in epoch 0 would cause OOM.')
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        logging.info(f'* criterion: {criterion} (={crit_values[criterion]}) ...')
+        batch = train_dl.dataset[cuts]
+        try:
+            compute_loss(params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True)
+            logging.info('OK!')
+        except RuntimeError as e:
+            if 'CUDA out of memory' in str(e):
+                logging.error(
+                    'Your GPU ran out of memory with the current max_duration setting. '
+                    'We recommend decreasing max_duration and trying again.'
+                )
+            raise
+
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 