From 14e088655941110865478cf40049bd82d6708aa6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 15 Aug 2021 11:45:53 +0800 Subject: [PATCH] Minor fixes. --- .../conformer_ctc_madam_no_warmup/train.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py index fe675be01..4ec296646 100755 --- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py +++ b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py @@ -194,7 +194,10 @@ def load_checkpoint_if_available( filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( - filename, model=model, optimizer=optimizer, scheduler=scheduler, + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, ) keys = [ @@ -512,6 +515,7 @@ def train_one_epoch( params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): if batch_idx == 0: + logging.info("save a batch for OOM handling") # Use this batch to replace the batch that's causing OOM params.saved_batch = batch @@ -597,7 +601,9 @@ def train_one_epoch( params.batch_idx_train, ) tb_writer.add_scalar( - "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train, + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, ) if batch_idx > 0 and batch_idx % params.reset_interval == 0: tot_loss = 0.0 # sum of losses over all batches @@ -646,6 +652,9 @@ def train_one_epoch( params.best_train_epoch = params.cur_epoch params.best_train_loss = params.train_loss + if "saved_batch" in params: + del params["saved_batch"] + def run(rank, world_size, args): """ @@ -749,10 +758,12 @@ def run(rank, world_size, args): tb_writer=tb_writer, world_size=world_size, ) - del params.saved_batch save_checkpoint( - params=params, model=model, optimizer=optimizer, rank=rank, + params=params, + model=model, + optimizer=optimizer, + rank=rank, ) logging.info("Done!")