mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Refactor OOM batch scanning into a local function
This commit is contained in:
parent
d509d58f30
commit
6fbd7a287c
@ -503,9 +503,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -606,47 +604,20 @@ def run(rank, world_size, args):
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
logging.info(
|
||||
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
||||
)
|
||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||
for criterion, cuts in batches.items():
|
||||
logging.info(
|
||||
f"* criterion: {criterion} (={crit_values[criterion]}) ..."
|
||||
)
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
optimizer.zero_grad()
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
batch=batch,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
params=params,
|
||||
)
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
logging.info("OK!")
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
"max_duration setting. We recommend decreasing "
|
||||
"max_duration and trying again."
|
||||
)
|
||||
raise
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
@ -679,6 +650,44 @@ def run(rank, world_size, args):
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
params: AttributeDict,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
logging.info(
|
||||
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
||||
)
|
||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
optimizer.zero_grad()
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
)
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
"max_duration setting. We recommend decreasing "
|
||||
"max_duration and trying again.\n"
|
||||
f"Failing criterion: {criterion} (={crit_values[criterion]}) ..."
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
|
Loading…
x
Reference in New Issue
Block a user