Merge pull request #82 from pzelasko/feature/find-pessimistic-batches

Find CUDA OOM batches before starting training
Piotr Żelasko 2021-10-19 11:26:13 -04:00 committed by GitHub
commit 902e0b238d


@@ -606,6 +606,14 @@ def run(rank, world_size, args):
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    scan_pessimistic_batches_for_oom(
        model=model,
        train_dl=train_dl,
        optimizer=optimizer,
        graph_compiler=graph_compiler,
        params=params,
    )

    for epoch in range(params.start_epoch, params.num_epochs):
        train_dl.sampler.set_epoch(epoch)
@@ -646,6 +654,45 @@ def run(rank, world_size, args):
        cleanup_dist()


def scan_pessimistic_batches_for_oom(
    model: nn.Module,
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    graph_compiler: BpeCtcTrainingGraphCompiler,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
    )
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
            optimizer.zero_grad()
            loss, _ = compute_loss(
                params=params,
                model=model,
                batch=batch,
                graph_compiler=graph_compiler,
                is_training=True,
            )
            loss.backward()
            clip_grad_norm_(model.parameters(), 5.0, 2.0)
            optimizer.step()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise


def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
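
For reference, the same pre-training OOM check can be sketched outside this script. The example below is a minimal, hypothetical sketch (not part of this commit): the helper name dry_run_worst_batches and the run_one_step callable are illustrative, and it assumes a map-style lhotse dataset indexed by a CutSet together with any lhotse cut sampler, as in the diff above.

    import logging

    from lhotse.dataset import find_pessimistic_batches


    def dry_run_worst_batches(dataset, sampler, run_one_step):
        """Run one training step on the most demanding batches the sampler
        can produce, so a CUDA OOM surfaces before real training starts.

        `run_one_step` is a user-supplied callable that performs a single
        forward/backward/update on a batch and raises on CUDA OOM.
        """
        # One entry per heuristic criterion tracked by lhotse; each maps
        # the criterion name to the worst-case CutSet (and its value).
        batches, crit_values = find_pessimistic_batches(sampler)
        for criterion, cuts in batches.items():
            logging.info(
                f"Trying the worst batch for criterion {criterion} "
                f"(value: {crit_values[criterion]})"
            )
            # Collate the selected cuts into a batch; if this step OOMs,
            # the full training run would eventually OOM as well.
            batch = dataset[cuts]
            run_one_step(batch)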