Fix OOM handling.

Fangjun Kuang 2021-08-15 09:52:17 +08:00
parent 36ac512d00
commit 72c0220830


@@ -3,6 +3,7 @@
 # This is just at the very beginning ...

 import argparse
+import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -193,10 +194,7 @@ def load_checkpoint_if_available(

     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )

     keys = [
@@ -281,58 +279,75 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)

-    supervisions = batch["supervisions"]
-    with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
-
-    # NOTE: We need `encode_supervisions` to sort sequences with
-    # different duration in decreasing order, required by
-    # `k2.intersect_dense` called in `k2.ctc_loss`
-    supervision_segments, texts = encode_supervisions(
-        supervisions, subsampling_factor=params.subsampling_factor
-    )
-
-    token_ids = graph_compiler.texts_to_ids(texts)
-
-    decoding_graph = graph_compiler.compile(token_ids)
-
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
-
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
-
-    if params.att_rate != 0.0:
-        with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-    else:
-        loss = ctc_loss
-        att_loss = torch.tensor([0])
+    try:
+        supervisions = batch["supervisions"]
+
+        with torch.set_grad_enabled(is_training):
+            nnet_output, encoder_memory, memory_mask = model(
+                feature, supervisions
+            )
+            # nnet_output is [N, T, C]
+
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+        token_ids = graph_compiler.texts_to_ids(texts)
+
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+
+        if params.att_rate != 0.0:
+            with torch.set_grad_enabled(is_training):
+                if hasattr(model, "module"):
+                    att_loss = model.module.decoder_forward(
+                        encoder_memory,
+                        memory_mask,
+                        token_ids=token_ids,
+                        sos_id=graph_compiler.sos_id,
+                        eos_id=graph_compiler.eos_id,
+                    )
+                else:
+                    att_loss = model.decoder_forward(
+                        encoder_memory,
+                        memory_mask,
+                        token_ids=token_ids,
+                        sos_id=graph_compiler.sos_id,
+                        eos_id=graph_compiler.eos_id,
+                    )
+            loss = (
+                1.0 - params.att_rate
+            ) * ctc_loss + params.att_rate * att_loss
+        else:
+            loss = ctc_loss
+            att_loss = torch.tensor([0])
+    except RuntimeError as ex:
+        try:
+            del nnet_output
+            del encoder_memory
+            del dense_fsa_vec
+            del ctc_loss
+            del att_loss
+            del loss
+        except NameError as ne:
+            pass
+        raise ex

     # train_frames and valid_frames are used for printing.
     if is_training:
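The hunk above wraps the body of compute_loss_impl in a try/except RuntimeError whose handler deletes whatever intermediate tensors already exist (nnet_output, encoder_memory, dense_fsa_vec, ctc_loss, att_loss, loss) before re-raising. Without the explicit del statements these locals stay reachable through the frame referenced by the exception's traceback, so the caller cannot actually release their GPU memory; the caller then decides whether the re-raised error is an OOM worth retrying. A minimal standalone sketch of the same pattern, where build_intermediates is an illustrative placeholder rather than an icefall function:

import torch


def build_intermediates(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the forward pass / loss computation that may hit OOM.
    return (x @ x.t()).relu()


def compute(x: torch.Tensor) -> torch.Tensor:
    try:
        y = build_intermediates(x)
        loss = y.sum()
    except RuntimeError as ex:
        # Drop references to tensors allocated before the failure so the
        # caller can free their memory.  Names bound after the failing
        # statement do not exist yet, hence the NameError guard; a failed
        # del simply skips the remaining ones, as in the diff above.
        try:
            del y
            del loss
        except NameError:
            pass
        raise ex
    return loss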
@@ -364,6 +379,7 @@ def compute_loss(
         if "out of memory" not in str(ex):
             raise ex

+        logging.exception(ex)
         s = f"\nCaught exception: {str(ex)}\n"
         total_duration = 0.0
         max_cut_duration = 0.0
@@ -375,19 +391,23 @@ def compute_loss(
         s += f"  max duration: {max_cut_duration:.3f} s \n"
         logging.info(s)

         # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
         for p in model.parameters():
             if p.grad is not None:
                 del p.grad  # free some memory
-        torch.cuda.empty_cache()

+        torch.cuda.empty_cache()
+
+        gc.collect()
+
+        # See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
         return compute_loss_impl(
             params=params,
             model=model,
             batch=params.saved_batch,
             graph_compiler=graph_compiler,
             is_training=is_training,
         )


 def compute_validation_loss(
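The two compute_loss hunks above refine the existing OOM fallback: the caught exception is now logged with logging.exception, and after the per-parameter .grad tensors are deleted the handler calls torch.cuda.empty_cache() followed by gc.collect() (hence the new import gc; the PyTorch issue referenced in the diff motivates the extra collection pass) before retrying the loss computation on the previously saved batch params.saved_batch. A condensed sketch of this catch-free-retry shape, with illustrative names (run_one_batch, fallback_batch) rather than the icefall API:

import gc
import logging

import torch


def run_with_oom_retry(model, batch, fallback_batch, run_one_batch):
    """Run one step; on CUDA OOM, free memory and retry on a batch known to fit."""
    try:
        return run_one_batch(model, batch)
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise ex
        logging.exception(ex)

        # Drop gradient buffers still held by the parameters.
        for p in model.parameters():
            if p.grad is not None:
                del p.grad

        # Release cached allocator blocks and force a garbage-collection
        # pass so that memory owned by dead Python objects is reclaimed.
        torch.cuda.empty_cache()
        gc.collect()

        # Retry once with a smaller, known-good batch.
        return run_one_batch(model, fallback_batch)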
@@ -577,9 +597,7 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                    "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
                 )
         if batch_idx > 0 and batch_idx % params.reset_interval == 0:
             tot_loss = 0.0  # sum of losses over all batches
@@ -734,10 +752,7 @@ def run(rank, world_size, args):

         del params.saved_batch
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
         )

     logging.info("Done!")