mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 02:06:13 +00:00
tensorboard should work properly
This commit is contained in:
parent
8da57a0449
commit
4483c6e700
@ -601,7 +601,7 @@ def compute_validation_loss(
|
||||
|
||||
# used to summary the stats over iterations
|
||||
tot_loss = MetricsTracker()
|
||||
returned_sample = None
|
||||
returned_sample = (None, None)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
@ -634,7 +634,7 @@ def compute_validation_loss(
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
forward_generator=True,
|
||||
return_sample=True,
|
||||
return_sample=False,
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
for k, v in stats_g.items():
|
||||
@ -649,8 +649,6 @@ def compute_validation_loss(
|
||||
inner_model = model.module if isinstance(model, DDP) else model
|
||||
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
|
||||
returned_sample = (audio_pred, audio)
|
||||
else:
|
||||
returned_sample = (None, None)
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user