run validation only at the end of epoch

wangtiance 2023-01-19 17:32:49 +08:00
parent 623fe22ff1
commit 249b9300e9


@@ -395,7 +395,6 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 500,
             "reset_interval": 200,
-            "valid_interval": 9000,
             "warm_step": 5000,
             "beam_size": 10,
             "use_double_scores": True,
@@ -927,24 +926,23 @@ def train_one_epoch(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
                     )
 
-        if batch_idx > 0 and batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                phone_lexicon=phone_lexicon,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+    logging.info("Computing validation loss")
+    valid_info = compute_validation_loss(
+        params=params,
+        model=model,
+        sp=sp,
+        phone_lexicon=phone_lexicon,
+        valid_dl=valid_dl,
+        world_size=world_size,
+    )
+    model.train()
+    logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+    logging.info(
+        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+    if tb_writer is not None:
+        valid_info.write_summary(
+            tb_writer, "train/valid_", params.batch_idx_train
+        )
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
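
For context, a minimal sketch of the resulting control flow (not the recipe's actual code; compute_validation_loss and the argument names follow the diff above, everything else is assumed): the validation pass now runs once per epoch, after the batch loop, instead of every valid_interval batches.

    import logging

    def compute_validation_loss(params, model, valid_dl):
        # Stand-in for the recipe's real helper (see the diff above).
        ...

    def train_one_epoch(model, train_dl, valid_dl, params):
        model.train()
        for batch_idx, batch in enumerate(train_dl):
            ...  # forward/backward, optimizer step, periodic logging

        # End of epoch: a single validation pass.
        logging.info("Computing validation loss")
        valid_info = compute_validation_loss(params=params, model=model, valid_dl=valid_dl)
        model.train()  # restore training mode; validation leaves the model in eval mode
        logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")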