From 72c0220830c8b3ac864f59a12961ad4fd57d84e9 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 15 Aug 2021 09:52:17 +0800
Subject: [PATCH] Fix oom handling.

---
 .../conformer_ctc_madam_no_warmup/train.py | 161 ++++++++++--------
 1 file changed, 88 insertions(+), 73 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
index 331e9f171..fe675be01 100755
--- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
@@ -3,6 +3,7 @@
 # This is just at the very beginning ...
 
 import argparse
+import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -193,10 +194,7 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
 
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
@@ -281,58 +279,75 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)
 
-    supervisions = batch["supervisions"]
-    with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+    try:
 
-    # NOTE: We need `encode_supervisions` to sort sequences with
-    # different duration in decreasing order, required by
-    # `k2.intersect_dense` called in `k2.ctc_loss`
-    supervision_segments, texts = encode_supervisions(
-        supervisions, subsampling_factor=params.subsampling_factor
-    )
-
-    token_ids = graph_compiler.texts_to_ids(texts)
-
-    decoding_graph = graph_compiler.compile(token_ids)
-
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
-
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
-
-    if params.att_rate != 0.0:
+        supervisions = batch["supervisions"]
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-    else:
-        loss = ctc_loss
-        att_loss = torch.tensor([0])
+            nnet_output, encoder_memory, memory_mask = model(
+                feature, supervisions
+            )
+            # nnet_output is [N, T, C]
+
+            # NOTE: We need `encode_supervisions` to sort sequences with
+            # different duration in decreasing order, required by
+            # `k2.intersect_dense` called in `k2.ctc_loss`
+            supervision_segments, texts = encode_supervisions(
+                supervisions, subsampling_factor=params.subsampling_factor
+            )
+
+            token_ids = graph_compiler.texts_to_ids(texts)
+
+            decoding_graph = graph_compiler.compile(token_ids)
+
+            dense_fsa_vec = k2.DenseFsaVec(
+                nnet_output,
+                supervision_segments,
+                allow_truncate=params.subsampling_factor - 1,
+            )
+
+            ctc_loss = k2.ctc_loss(
+                decoding_graph=decoding_graph,
+                dense_fsa_vec=dense_fsa_vec,
+                output_beam=params.beam_size,
+                reduction=params.reduction,
+                use_double_scores=params.use_double_scores,
+            )
+
+            if params.att_rate != 0.0:
+                with torch.set_grad_enabled(is_training):
+                    if hasattr(model, "module"):
+                        att_loss = model.module.decoder_forward(
+                            encoder_memory,
+                            memory_mask,
+                            token_ids=token_ids,
+                            sos_id=graph_compiler.sos_id,
+                            eos_id=graph_compiler.eos_id,
+                        )
+                    else:
+                        att_loss = model.decoder_forward(
+                            encoder_memory,
+                            memory_mask,
+                            token_ids=token_ids,
+                            sos_id=graph_compiler.sos_id,
+                            eos_id=graph_compiler.eos_id,
+                        )
+                loss = (
+                    1.0 - params.att_rate
+                ) * ctc_loss + params.att_rate * att_loss
+            else:
+                loss = ctc_loss
+                att_loss = torch.tensor([0])
+    except RuntimeError as ex:
+        try:
+            del nnet_output
+            del encoder_memory
+            del dense_fsa_vec
+            del ctc_loss
+            del att_loss
+            del loss
+        except NameError as ne:
+            pass
+        raise ex
 
     # train_frames and valid_frames are used for printing.
     if is_training:
@@ -364,6 +379,7 @@
         if "out of memory" not in str(ex):
             raise ex
 
+        logging.exception(ex)
         s = f"\nCaught exception: {str(ex)}\n"
         total_duration = 0.0
         max_cut_duration = 0.0
@@ -375,19 +391,23 @@
         s += f" max duration: {max_cut_duration:.3f} s \n"
         logging.info(s)
 
-        # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
-        for p in model.parameters():
-            if p.grad is not None:
-                del p.grad  # free some memory
-        torch.cuda.empty_cache()
+    # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
+    for p in model.parameters():
+        if p.grad is not None:
+            del p.grad  # free some memory
 
-        return compute_loss_impl(
-            params=params,
-            model=model,
-            batch=params.saved_batch,
-            graph_compiler=graph_compiler,
-            is_training=is_training,
-        )
+    torch.cuda.empty_cache()
+
+    gc.collect()
+
+    # See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
+    return compute_loss_impl(
+        params=params,
+        model=model,
+        batch=params.saved_batch,
+        graph_compiler=graph_compiler,
+        is_training=is_training,
+    )
 
 
 def compute_validation_loss(
@@ -577,9 +597,7 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                    "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
                 )
         if batch_idx > 0 and batch_idx % params.reset_interval == 0:
             tot_loss = 0.0  # sum of losses over all batches
@@ -734,10 +752,7 @@ def run(rank, world_size, args):
         del params.saved_batch
 
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
        )
 
     logging.info("Done!")
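
Taken together, the changes follow a common CUDA OOM-recovery pattern: compute the loss inside a helper that deletes its large locals before re-raising, log the failure inside the except handler, and only free gradients, empty the CUDA cache, call gc.collect(), and retry after the handler has exited, so the caught exception and its traceback no longer pin the tensors of the failed attempt. A minimal standalone sketch of that pattern, assuming plain PyTorch (compute_with_oom_retry and compute_fn are illustrative names, not part of icefall):

import gc
import logging

import torch


def compute_with_oom_retry(model: torch.nn.Module, compute_fn, *args, **kwargs):
    # Call compute_fn once; on a CUDA OOM error, free memory and retry a single time.
    # compute_fn should delete its own large local tensors before re-raising,
    # otherwise the exception traceback keeps them alive.
    try:
        return compute_fn(*args, **kwargs)
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        logging.exception(ex)

    # This code runs outside the except block, so `ex` and its traceback
    # (which reference the tensors of the failed attempt) have been released.
    for p in model.parameters():
        if p.grad is not None:
            del p.grad  # free some memory
    torch.cuda.empty_cache()
    gc.collect()
    return compute_fn(*args, **kwargs)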