mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
run validation only at the end of epoch
This commit is contained in:
parent
623fe22ff1
commit
249b9300e9
@@ -395,7 +395,6 @@ def get_params() -> AttributeDict:
|
|||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 500,
|
"log_interval": 500,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 9000,
|
|
||||||
"warm_step": 5000,
|
"warm_step": 5000,
|
||||||
"beam_size": 10,
|
"beam_size": 10,
|
||||||
"use_double_scores": True,
|
"use_double_scores": True,
|
||||||
@@ -927,24 +926,23 @@ def train_one_epoch(
|
|||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
logging.info("Computing validation loss")
|
||||||
logging.info("Computing validation loss")
|
valid_info = compute_validation_loss(
|
||||||
valid_info = compute_validation_loss(
|
params=params,
|
||||||
params=params,
|
model=model,
|
||||||
model=model,
|
sp=sp,
|
||||||
sp=sp,
|
phone_lexicon=phone_lexicon,
|
||||||
phone_lexicon=phone_lexicon,
|
valid_dl=valid_dl,
|
||||||
valid_dl=valid_dl,
|
world_size=world_size,
|
||||||
world_size=world_size,
|
)
|
||||||
)
|
model.train()
|
||||||
model.train()
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
logging.info(
|
||||||
logging.info(
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
||||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
if tb_writer is not None:
|
||||||
if tb_writer is not None:
|
valid_info.write_summary(
|
||||||
valid_info.write_summary(
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
)
|
||||||
)
|
|
||||||
|
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user