Run validation only at the end of each epoch

This commit is contained in:
wangtiance 2023-01-19 17:32:49 +08:00
parent 623fe22ff1
commit 249b9300e9

View File

@@ -395,7 +395,6 @@ def get_params() -> AttributeDict:
         "batch_idx_train": 0,
         "log_interval": 500,
         "reset_interval": 200,
-        "valid_interval": 9000,
         "warm_step": 5000,
         "beam_size": 10,
         "use_double_scores": True,
@@ -927,24 +926,23 @@ def train_one_epoch(
                    "train/grad_scale", cur_grad_scale, params.batch_idx_train
                )
-            if batch_idx > 0 and batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-                logging.info("Computing validation loss")
-                valid_info = compute_validation_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    phone_lexicon=phone_lexicon,
-                    valid_dl=valid_dl,
-                    world_size=world_size,
-                )
-                model.train()
-                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-                logging.info(
-                    f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
-                if tb_writer is not None:
-                    valid_info.write_summary(
-                        tb_writer, "train/valid_", params.batch_idx_train
-                    )
+    logging.info("Computing validation loss")
+    valid_info = compute_validation_loss(
+        params=params,
+        model=model,
+        sp=sp,
+        phone_lexicon=phone_lexicon,
+        valid_dl=valid_dl,
+        world_size=world_size,
+    )
+    model.train()
+    logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+    logging.info(
+        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+    if tb_writer is not None:
+        valid_info.write_summary(
+            tb_writer, "train/valid_", params.batch_idx_train
+        )
    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value