From 1c7c79f2fcf871033ca640d86740effc976f7a3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Thu, 14 Oct 2021 21:28:11 -0400
Subject: [PATCH 1/6] Find CUDA OOM batches before starting training

---
 egs/librispeech/ASR/conformer_ctc/train.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index d1cdfa8bb..267316691 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -606,6 +606,23 @@ def run(rank, world_size, args):
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
 
+    from lhotse.dataset import find_pessimistic_batches
+    logging.info('Sanity check -- see if any of the batches in epoch 0 would cause OOM.')
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        logging.info(f'* criterion: {criterion} (={crit_values[criterion]}) ...')
+        batch = train_dl.dataset[cuts]
+        try:
+            compute_loss(params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True)
+            logging.info('OK!')
+        except RuntimeError as e:
+            if 'CUDA out of memory' in str(e):
+                logging.error(
+                    'Your GPU ran out of memory with the current max_duration setting. '
+                    'We recommend decreasing max_duration and trying again.'
+                )
+            raise
+
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 

From 060117a9ff3eed59d96c1bfa4f0ad3e2082fdac8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Thu, 14 Oct 2021 21:40:14 -0400
Subject: [PATCH 2/6] Reformatting

---
 egs/librispeech/ASR/conformer_ctc/train.py | 26 ++++++++++++++++------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 267316691..9d3de020e 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -607,19 +607,31 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders()
 
     from lhotse.dataset import find_pessimistic_batches
-    logging.info('Sanity check -- see if any of the batches in epoch 0 would cause OOM.')
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
-        logging.info(f'* criterion: {criterion} (={crit_values[criterion]}) ...')
+        logging.info(
+            f"* criterion: {criterion} (={crit_values[criterion]}) ..."
+        )
         batch = train_dl.dataset[cuts]
         try:
-            compute_loss(params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True)
-            logging.info('OK!')
+            compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            logging.info("OK!")
         except RuntimeError as e:
-            if 'CUDA out of memory' in str(e):
+            if "CUDA out of memory" in str(e):
                 logging.error(
-                    'Your GPU ran out of memory with the current max_duration setting. '
-                    'We recommend decreasing max_duration and trying again.'
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again."
                 )
             raise
 

From 403d1744fff4834bf316b1fd715bdceda62646fa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Fri, 15 Oct 2021 10:05:13 -0400
Subject: [PATCH 3/6] Introduce backprop in finding OOM batches

---
 egs/librispeech/ASR/conformer_ctc/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 9d3de020e..db167fd9d 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -618,13 +618,17 @@ def run(rank, world_size, args):
         )
         batch = train_dl.dataset[cuts]
         try:
-            compute_loss(
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
                 params=params,
                 model=model,
                 batch=batch,
                 graph_compiler=graph_compiler,
                 is_training=True,
             )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
             logging.info("OK!")
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):

From 6fbd7a287c7ab5657fe3807764ec4d600f8a27e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Mon, 18 Oct 2021 09:53:04 -0400
Subject: [PATCH 4/6] Refactor OOM batch scanning into a local function

---
 egs/librispeech/ASR/conformer_ctc/train.py | 83 ++++++++++++----------
 1 file changed, 46 insertions(+), 37 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 8e5a15285..de8e078f8 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -503,9 +503,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -606,47 +604,20 @@ def run(rank, world_size, args):
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
 
-    from lhotse.dataset import find_pessimistic_batches
-
-    logging.info(
-        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
     )
-    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-    for criterion, cuts in batches.items():
-        logging.info(
-            f"* criterion: {criterion} (={crit_values[criterion]}) ..."
-        )
-        batch = train_dl.dataset[cuts]
-        try:
-            optimizer.zero_grad()
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                batch=batch,
-                graph_compiler=graph_compiler,
-                is_training=True,
-            )
-            loss.backward()
-            clip_grad_norm_(model.parameters(), 5.0, 2.0)
-            optimizer.step()
-            logging.info("OK!")
-        except RuntimeError as e:
-            if "CUDA out of memory" in str(e):
-                logging.error(
-                    "Your GPU ran out of memory with the current "
-                    "max_duration setting. We recommend decreasing "
-                    "max_duration and trying again."
-                )
-            raise
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
@@ -679,6 +650,44 @@ def run(rank, world_size, args):
     cleanup_dist()
 
 
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} (={crit_values[criterion]}) ..."
+                )
+            raise
+
+
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)

From 86f3e0ef37a8e52508f175bb2e75561cd5c6f6c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Mon, 18 Oct 2021 09:54:40 -0400
Subject: [PATCH 5/6] Make flake8 happy

---
 egs/librispeech/ASR/conformer_ctc/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index de8e078f8..c2b9c1fb6 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -683,7 +683,8 @@ def scan_pessimistic_batches_for_oom(
                     "Your GPU ran out of memory with the current "
                     "max_duration setting. We recommend decreasing "
                     "max_duration and trying again.\n"
-                    f"Failing criterion: {criterion} (={crit_values[criterion]}) ..."
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
                 )
             raise
 

From 3cc99d2af2d7ebddca80770d8d192f82dcc7cb89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Tue, 19 Oct 2021 11:24:54 -0400
Subject: [PATCH 6/6] make flake8 happy

---
 egs/librispeech/ASR/conformer_ctc/train.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index c2b9c1fb6..223c8d993 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -503,7 +503,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -617,7 +619,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
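
For readers who want to reuse this pre-training OOM check outside the icefall recipe, here is a minimal sketch of the same fail-fast pattern as a standalone helper. It is an illustration under assumptions, not part of the patches: `scan_for_oom`, `model`, `optimizer`, and `compute_loss` are hypothetical placeholder names for whatever a given training script defines; only `find_pessimistic_batches` and the `train_dl.dataset[cuts]` indexing follow the Lhotse usage shown in the diffs above. Note that, as PATCH 3/6 does, the scan runs backward() and an optimizer step rather than just a forward pass, so that peak memory (activations, gradients, optimizer state) matches what real training will need.

import logging

from lhotse.dataset import find_pessimistic_batches
from torch.nn.utils import clip_grad_norm_


def scan_for_oom(model, train_dl, optimizer, compute_loss):
    """Run a full training step on the largest batches the sampler can emit,
    so a CUDA OOM surfaces before hours of training are wasted."""
    logging.info("Sanity check -- see if any of the batches would cause OOM.")
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        # Collate the worst-case CutSet into a batch via the dataset itself.
        batch = train_dl.dataset[cuts]
        try:
            optimizer.zero_grad()
            loss = compute_loss(model, batch)  # placeholder loss function
            loss.backward()  # include gradients so memory peaks realistically
            clip_grad_norm_(model.parameters(), 5.0, 2.0)
            optimizer.step()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current max_duration "
                    "setting. We recommend decreasing max_duration and trying "
                    f"again.\nFailing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise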