From 4483c6e700975746061c471ac90033f0a2c54e49 Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 21:52:59 +0800 Subject: [PATCH] tensorboard should work properly --- egs/libritts/CODEC/encodec/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index e207b12f7..73d698008 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -601,7 +601,7 @@ def compute_validation_loss( # used to summary the stats over iterations tot_loss = MetricsTracker() - returned_sample = None + returned_sample = (None, None) with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): @@ -634,7 +634,7 @@ def compute_validation_loss( speech_lengths=audio_lens, global_step=params.batch_idx_train, forward_generator=True, - return_sample=True, + return_sample=False, ) assert loss_g.requires_grad is False for k, v in stats_g.items(): @@ -649,8 +649,6 @@ def compute_validation_loss( inner_model = model.module if isinstance(model, DDP) else model audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw) returned_sample = (audio_pred, audio) - else: - returned_sample = (None, None) if world_size > 1: tot_loss.reduce(device)