Fix OOM handling.

Fangjun Kuang 2021-08-15 09:52:17 +08:00
parent 36ac512d00
commit 72c0220830

@@ -3,6 +3,7 @@
# This is just at the very beginning ...
import argparse
import gc
import logging
from pathlib import Path
from shutil import copyfile
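
Note: the new "import gc" supports the OOM recovery path added below in compute_loss. Emptying the CUDA cache alone does not return memory that is still reachable from Python objects (for example via an exception's traceback), so the handler also forces a garbage-collection pass; see the PyTorch issue linked later in this diff. A minimal illustrative sketch of that interaction (standalone toy code, not part of this commit):

import gc

import torch

# Allocate a tensor, drop the reference, then reclaim the memory.
# gc.collect() releases unreachable Python objects that may still own
# CUDA storage; torch.cuda.empty_cache() then returns the cached blocks
# to the driver so a retried batch (or another process) can use them.
x = torch.empty(1024, 1024, device="cuda")
del x
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
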
@@ -193,10 +194,7 @@ def load_checkpoint_if_available(
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
filename, model=model, optimizer=optimizer, scheduler=scheduler,
)
keys = [
@@ -281,58 +279,75 @@ def compute_loss_impl(
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
try:
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate != 0.0:
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
if hasattr(model, "module"):
att_loss = model.module.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
nnet_output, encoder_memory, memory_mask = model(
feature, supervisions
)
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
if hasattr(model, "module"):
att_loss = model.module.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (
1.0 - params.att_rate
) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
except RuntimeError as ex:
try:
del nnet_output
del encoder_memory
del dense_fsa_vec
del ctc_loss
del att_loss
del loss
except NameError as ne:
pass
raise ex
# train_frames and valid_frames are used for printing.
if is_training:
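
Note: the try/except introduced above makes compute_loss_impl drop whatever intermediate tensors it managed to create (nnet_output, encoder_memory, dense_fsa_vec, the losses) before re-raising, so a CUDA OOM raised mid-way does not keep those tensors alive through the exception while the caller tries to recover; the NameError guard covers failures that happen before some of the locals exist. A condensed sketch of the same pattern, where forward_and_compute_loss is a hypothetical stand-in for the model forward pass plus k2.ctc_loss:

import torch


def forward_and_compute_loss(model, batch):
    # Hypothetical stand-in for the encoder forward pass and k2.ctc_loss.
    return model(batch).sum()


def compute_loss_impl(model, batch):
    try:
        loss = forward_and_compute_loss(model, batch)
        return loss
    except RuntimeError:
        # Delete locals that may hold large CUDA tensors before the
        # exception (and its traceback) propagates to the caller.
        try:
            del loss
        except NameError:
            # The failure happened before loss was assigned.
            pass
        raise
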
@@ -364,6 +379,7 @@ def compute_loss(
if "out of memory" not in str(ex):
raise ex
logging.exception(ex)
s = f"\nCaught exception: {str(ex)}\n"
total_duration = 0.0
max_cut_duration = 0.0
@@ -375,19 +391,23 @@ def compute_loss(
s += f" max duration: {max_cut_duration:.3f} s \n"
logging.info(s)
# see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
for p in model.parameters():
if p.grad is not None:
del p.grad # free some memory
torch.cuda.empty_cache()
# see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
for p in model.parameters():
if p.grad is not None:
del p.grad # free some memory
return compute_loss_impl(
params=params,
model=model,
batch=params.saved_batch,
graph_compiler=graph_compiler,
is_training=is_training,
)
torch.cuda.empty_cache()
gc.collect()
# See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
return compute_loss_impl(
params=params,
model=model,
batch=params.saved_batch,
graph_compiler=graph_compiler,
is_training=is_training,
)
def compute_validation_loss(
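
Note: with the changes above, compute_loss now recovers from a CUDA out-of-memory error by deleting per-parameter gradients, emptying the CUDA cache, running gc.collect() (per the linked PyTorch issue), and then retrying once on params.saved_batch. A simplified sketch of this catch-and-retry wrapper; the function signature and names here are illustrative, not the actual icefall API:

import gc
import logging

import torch


def compute_loss(model, batch, saved_batch, compute_loss_impl):
    # Sketch of the retry-on-OOM wrapper; compute_loss_impl is assumed to
    # clean up its own locals before re-raising (see the previous hunk).
    try:
        return compute_loss_impl(model, batch)
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        logging.exception(ex)

        # Free gradients that still occupy GPU memory.
        for p in model.parameters():
            if p.grad is not None:
                del p.grad

        torch.cuda.empty_cache()
        gc.collect()

        # Retry once with a batch saved earlier for this purpose.
        return compute_loss_impl(model, saved_batch)
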
@@ -577,9 +597,7 @@ def train_one_epoch(
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
"train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
@@ -734,10 +752,7 @@ def run(rank, world_size, args):
del params.saved_batch
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
params=params, model=model, optimizer=optimizer, rank=rank,
)
logging.info("Done!")