From 249b9300e9b8c39aaf90237c3f72b3108bfb0f13 Mon Sep 17 00:00:00 2001 From: wangtiance Date: Thu, 19 Jan 2023 17:32:49 +0800 Subject: [PATCH] run validation only at the end of epoch --- .../ASR/tiny_transducer_ctc/train.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 34f7178e9..25adaff58 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -395,7 +395,6 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 500, "reset_interval": 200, - "valid_interval": 9000, "warm_step": 5000, "beam_size": 10, "use_double_scores": True, @@ -927,24 +926,23 @@ def train_one_epoch( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) - if batch_idx > 0 and batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - phone_lexicon=phone_lexicon, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value