Refactor OOM batch scanning into a local function

Piotr Żelasko 2021-10-18 09:53:04 -04:00
parent d509d58f30
commit 6fbd7a287c


@@ -503,9 +503,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -606,47 +604,20 @@ def run(rank, world_size, args):
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
 
-    from lhotse.dataset import find_pessimistic_batches
-
-    logging.info(
-        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
-    )
-    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-    for criterion, cuts in batches.items():
-        logging.info(
-            f"* criterion: {criterion} (={crit_values[criterion]}) ..."
-        )
-        batch = train_dl.dataset[cuts]
-        try:
-            optimizer.zero_grad()
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                batch=batch,
-                graph_compiler=graph_compiler,
-                is_training=True,
-            )
-            loss.backward()
-            clip_grad_norm_(model.parameters(), 5.0, 2.0)
-            optimizer.step()
-            logging.info("OK!")
-        except RuntimeError as e:
-            if "CUDA out of memory" in str(e):
-                logging.error(
-                    "Your GPU ran out of memory with the current "
-                    "max_duration setting. We recommend decreasing "
-                    "max_duration and trying again."
-                )
-            raise
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
+    )
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
@@ -679,6 +650,44 @@ def run(rank, world_size, args):
     cleanup_dist()
 
 
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} (={crit_values[criterion]}) ..."
+                )
+            raise
+
+
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
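
For context, the extracted helper wraps lhotse's `find_pessimistic_batches`, which asks the sampler for the batches that maximize each of its batch-size criteria, so the most memory-hungry batches can be dry-run before the epoch loop starts. Below is a minimal, hedged sketch of that same pattern outside this script; `dataset`, `sampler`, and `run_one_training_step` are placeholder names for illustration and are not part of this commit.

```python
# Hypothetical sketch (not from this commit): dry-run the "worst case" batches
# a sampler can produce, to surface CUDA OOM errors before training begins.
import logging

from lhotse.dataset import find_pessimistic_batches


def dry_run_worst_case_batches(dataset, sampler, run_one_training_step):
    # `batches` maps a criterion name to the CutSet that maximizes it;
    # `crit_values` holds the corresponding numeric value per criterion.
    batches, crit_values = find_pessimistic_batches(sampler)
    for criterion, cuts in batches.items():
        batch = dataset[cuts]
        try:
            # Placeholder callable: should run forward + backward + optimizer step.
            run_one_training_step(batch)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    f"OOM on the batch maximizing {criterion} "
                    f"(={crit_values[criterion]}); consider lowering max_duration."
                )
            raise
```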