from local

This commit is contained in:
dohe0342 2023-02-03 01:29:41 +09:00
parent 46ec2c8312
commit ef3051f66e
2 changed files with 4 additions and 2 deletions

View File

@ -460,6 +460,10 @@ def compute_validation_loss(
tot_loss.reduce(loss.device) tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"] loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss.device == 0:
wb.log({"valid/loss": loss_value)
if params.cur_epoch < 10: if params.cur_epoch < 10:
params.best_valid_losses[params.cur_epoch] = loss_value params.best_valid_losses[params.cur_epoch] = loss_value
@ -552,8 +556,6 @@ def train_one_epoch(
) )
wb.log({"train/loss": tot_loss}) wb.log({"train/loss": tot_loss})
if batch_idx % params.log_interval == 0:
if tb_writer is not None: if tb_writer is not None:
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train