Fix oom handling.

Fangjun Kuang 2021-08-15 09:52:17 +08:00
parent 36ac512d00
commit 72c0220830

@@ -3,6 +3,7 @@
 # This is just at the very beginning ...
 import argparse
+import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -193,10 +194,7 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
     keys = [
@@ -281,9 +279,13 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)
+    try:
         supervisions = batch["supervisions"]
         with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+            nnet_output, encoder_memory, memory_mask = model(
+                feature, supervisions
+            )
         # nnet_output is [N, T, C]
         # NOTE: We need `encode_supervisions` to sort sequences with
@@ -329,10 +331,23 @@ def compute_loss_impl(
                 sos_id=graph_compiler.sos_id,
                 eos_id=graph_compiler.eos_id,
             )
-            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+            loss = (
+                1.0 - params.att_rate
+            ) * ctc_loss + params.att_rate * att_loss
         else:
             loss = ctc_loss
             att_loss = torch.tensor([0])
+    except RuntimeError as ex:
+        try:
+            del nnet_output
+            del encoder_memory
+            del dense_fsa_vec
+            del ctc_loss
+            del att_loss
+            del loss
+        except NameError as ne:
+            pass
+        raise ex

     # train_frames and valid_frames are used for printing.
     if is_training:
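
Note on the except block above: the re-raised exception's traceback keeps the frame of compute_loss_impl, and therefore its local tensors, alive, so the caller cannot reclaim that GPU memory with torch.cuda.empty_cache() alone. Deleting the half-built tensors first (with NameError guarding locals that were never assigned because the failure happened earlier) releases those references. A minimal sketch of the same pattern, independent of icefall; the function name forward_once and the toy loss are made up for illustration:

import torch


def forward_once(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    try:
        activations = model(batch)        # may raise a CUDA OOM RuntimeError
        loss = activations.pow(2).mean()  # further intermediate tensors
        return loss
    except RuntimeError as ex:
        # Drop references to tensors created before the failure so the
        # caller's empty_cache()/gc.collect() can actually free them.
        try:
            del activations
            del loss
        except NameError:
            pass
        raise ex
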
@@ -364,6 +379,7 @@ def compute_loss(
         if "out of memory" not in str(ex):
             raise ex
+        logging.exception(ex)
         s = f"\nCaught exception: {str(ex)}\n"
         total_duration = 0.0
         max_cut_duration = 0.0
@@ -379,8 +395,12 @@ def compute_loss(
         for p in model.parameters():
             if p.grad is not None:
                 del p.grad  # free some memory
         torch.cuda.empty_cache()
+        gc.collect()
+
+        # See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
         return compute_loss_impl(
             params=params,
             model=model,
@@ -577,9 +597,7 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                    "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
                 )
             if batch_idx > 0 and batch_idx % params.reset_interval == 0:
                 tot_loss = 0.0  # sum of losses over all batches
@@ -734,10 +752,7 @@ def run(rank, world_size, args):
         del params.saved_batch
     save_checkpoint(
-        params=params,
-        model=model,
-        optimizer=optimizer,
-        rank=rank,
+        params=params, model=model, optimizer=optimizer, rank=rank,
     )
     logging.info("Done!")