tensorboard should work properly

This commit is contained in:
JinZr 2024-09-06 21:52:59 +08:00
parent 8da57a0449
commit 4483c6e700

View File

@ -601,7 +601,7 @@ def compute_validation_loss(
# used to summary the stats over iterations
tot_loss = MetricsTracker()
returned_sample = None
returned_sample = (None, None)
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
@ -634,7 +634,7 @@ def compute_validation_loss(
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
forward_generator=True,
return_sample=True,
return_sample=False,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
@ -649,8 +649,6 @@ def compute_validation_loss(
inner_model = model.module if isinstance(model, DDP) else model
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
returned_sample = (audio_pred, audio)
else:
returned_sample = (None, None)
if world_size > 1:
tot_loss.reduce(device)