From 6fbd7a287c7ab5657fe3807764ec4d600f8a27e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Mon, 18 Oct 2021 09:53:04 -0400
Subject: [PATCH] Refactor OOM batch scanning into a local function

---
 egs/librispeech/ASR/conformer_ctc/train.py | 83 ++++++++++++----------
 1 file changed, 46 insertions(+), 37 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 8e5a15285..de8e078f8 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -503,9 +503,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -606,47 +604,20 @@ def run(rank, world_size, args):
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
 
-    from lhotse.dataset import find_pessimistic_batches
-
-    logging.info(
-        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
     )
-    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-    for criterion, cuts in batches.items():
-        logging.info(
-            f"* criterion: {criterion} (={crit_values[criterion]}) ..."
-        )
-        batch = train_dl.dataset[cuts]
-        try:
-            optimizer.zero_grad()
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                batch=batch,
-                graph_compiler=graph_compiler,
-                is_training=True,
-            )
-            loss.backward()
-            clip_grad_norm_(model.parameters(), 5.0, 2.0)
-            optimizer.step()
-            logging.info("OK!")
-        except RuntimeError as e:
-            if "CUDA out of memory" in str(e):
-                logging.error(
-                    "Your GPU ran out of memory with the current "
-                    "max_duration setting. We recommend decreasing "
-                    "max_duration and trying again."
-                )
-            raise
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
@@ -679,6 +650,44 @@ def run(rank, world_size, args):
     cleanup_dist()
 
 
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} (={crit_values[criterion]}) ..."
+                )
+            raise
+
+
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
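
Note on the pattern factored out above: lhotse's find_pessimistic_batches inspects a sampler and returns, for each sampling criterion, the batch most likely to blow up GPU memory together with that criterion's value; one forward/backward pass over each such batch then surfaces CUDA OOM failures before epoch 0 starts. The sketch below is illustrative only and not part of the patch: report_pessimistic_batches is a hypothetical helper that merely summarizes the worst-case batches (for example, to pick a safe --max-duration before launching training), assuming a Lhotse sampler such as train_dl.sampler.

import logging

from lhotse.dataset import find_pessimistic_batches


def report_pessimistic_batches(sampler) -> None:
    # For each criterion (e.g. largest total duration, largest number of cuts),
    # Lhotse returns the batch that maximizes it plus the criterion's value.
    batches, crit_values = find_pessimistic_batches(sampler)
    for criterion, cuts in batches.items():
        # Each entry is a CutSet; summarize it to judge how heavy the batch is.
        total_duration = sum(cut.duration for cut in cuts)
        logging.info(
            f"criterion: {criterion} (={crit_values[criterion]}): "
            f"{len(cuts)} cuts, {total_duration:.1f}s of audio"
        )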