mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Refactor OOM batch scanning into a local function
This commit is contained in:
parent
d509d58f30
commit
6fbd7a287c
@ -503,9 +503,7 @@ def train_one_epoch(
|
|||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
tot_loss.write_summary(
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||||
tb_writer, "train/tot_", params.batch_idx_train
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
@ -606,47 +604,20 @@ def run(rank, world_size, args):
|
|||||||
train_dl = librispeech.train_dataloaders()
|
train_dl = librispeech.train_dataloaders()
|
||||||
valid_dl = librispeech.valid_dataloaders()
|
valid_dl = librispeech.valid_dataloaders()
|
||||||
|
|
||||||
from lhotse.dataset import find_pessimistic_batches
|
scan_pessimistic_batches_for_oom(
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
|
||||||
)
|
|
||||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
|
||||||
for criterion, cuts in batches.items():
|
|
||||||
logging.info(
|
|
||||||
f"* criterion: {criterion} (={crit_values[criterion]}) ..."
|
|
||||||
)
|
|
||||||
batch = train_dl.dataset[cuts]
|
|
||||||
try:
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss, _ = compute_loss(
|
|
||||||
params=params,
|
|
||||||
model=model,
|
model=model,
|
||||||
batch=batch,
|
train_dl=train_dl,
|
||||||
|
optimizer=optimizer,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
is_training=True,
|
params=params,
|
||||||
)
|
)
|
||||||
loss.backward()
|
|
||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
|
||||||
optimizer.step()
|
|
||||||
logging.info("OK!")
|
|
||||||
except RuntimeError as e:
|
|
||||||
if "CUDA out of memory" in str(e):
|
|
||||||
logging.error(
|
|
||||||
"Your GPU ran out of memory with the current "
|
|
||||||
"max_duration setting. We recommend decreasing "
|
|
||||||
"max_duration and trying again."
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
|
||||||
)
|
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
@ -679,6 +650,44 @@ def run(rank, world_size, args):
|
|||||||
cleanup_dist()
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def scan_pessimistic_batches_for_oom(
    model: nn.Module,
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    graph_compiler: BpeCtcTrainingGraphCompiler,
    params: AttributeDict,
):
    """Run one full training step on each "pessimistic" batch before training.

    The batches are obtained from lhotse's ``find_pessimistic_batches``
    applied to ``train_dl.sampler``.  For every returned batch this function
    zeroes the gradients, computes the loss via ``compute_loss``, runs the
    backward pass, clips the gradient norm and calls ``optimizer.step()``.

    Note that these are real optimizer steps: ``model`` and ``optimizer``
    state are modified by this check.

    Args:
      model:
        The model passed to ``compute_loss``; its parameters are updated.
      train_dl:
        Training dataloader.  Its ``sampler`` is scanned for batches and its
        ``dataset`` is indexed with the selected cuts to build each batch.
      optimizer:
        Optimizer used for the trial steps.
      graph_compiler:
        Forwarded unchanged to ``compute_loss``.
      params:
        Forwarded unchanged to ``compute_loss``.

    Raises:
      RuntimeError:
        Any ``RuntimeError`` raised during a trial step is re-raised.  When
        its message contains "CUDA out of memory", an error suggesting a
        smaller ``max_duration`` is logged first, together with the failing
        criterion and its value.
    """
    # Kept as a function-local import, as in the original code.
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
    )

    # Both return values are indexed by criterion below; presumably
    # ``crit_values`` holds the measured value per criterion -- TODO confirm
    # against the lhotse documentation.
    worst_batches, crit_values = find_pessimistic_batches(train_dl.sampler)

    for criterion, selected_cuts in worst_batches.items():
        batch = train_dl.dataset[selected_cuts]
        try:
            optimizer.zero_grad()
            loss, _ = compute_loss(
                params=params,
                model=model,
                batch=batch,
                graph_compiler=graph_compiler,
                is_training=True,
            )
            loss.backward()
            clip_grad_norm_(model.parameters(), 5.0, 2.0)
            optimizer.step()
        except RuntimeError as err:
            # Only OOM failures get the extra hint; every RuntimeError is
            # propagated to the caller regardless.
            if "CUDA out of memory" not in str(err):
                raise
            logging.error(
                "Your GPU ran out of memory with the current "
                "max_duration setting. We recommend decreasing "
                "max_duration and trying again.\n"
                f"Failing criterion: {criterion} (={crit_values[criterion]}) ..."
            )
            raise
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user