mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge pull request #82 from pzelasko/feature/find-pessimistic-batches
Find CUDA OOM batches before starting training
This commit is contained in:
commit
902e0b238d
@ -606,6 +606,14 @@ def run(rank, world_size, args):
|
|||||||
train_dl = librispeech.train_dataloaders()
|
train_dl = librispeech.train_dataloaders()
|
||||||
valid_dl = librispeech.valid_dataloaders()
|
valid_dl = librispeech.valid_dataloaders()
|
||||||
|
|
||||||
|
scan_pessimistic_batches_for_oom(
|
||||||
|
model=model,
|
||||||
|
train_dl=train_dl,
|
||||||
|
optimizer=optimizer,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
@ -646,6 +654,45 @@ def run(rank, world_size, args):
|
|||||||
cleanup_dist()
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def scan_pessimistic_batches_for_oom(
|
||||||
|
model: nn.Module,
|
||||||
|
train_dl: torch.utils.data.DataLoader,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||||
|
params: AttributeDict,
|
||||||
|
):
|
||||||
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
||||||
|
)
|
||||||
|
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||||
|
for criterion, cuts in batches.items():
|
||||||
|
batch = train_dl.dataset[cuts]
|
||||||
|
try:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss, _ = compute_loss(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
batch=batch,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
|
is_training=True,
|
||||||
|
)
|
||||||
|
loss.backward()
|
||||||
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
|
optimizer.step()
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "CUDA out of memory" in str(e):
|
||||||
|
logging.error(
|
||||||
|
"Your GPU ran out of memory with the current "
|
||||||
|
"max_duration setting. We recommend decreasing "
|
||||||
|
"max_duration and trying again.\n"
|
||||||
|
f"Failing criterion: {criterion} "
|
||||||
|
f"(={crit_values[criterion]}) ..."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user