mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
tensorboard should work properly
This commit is contained in:
parent
8da57a0449
commit
4483c6e700
@ -601,7 +601,7 @@ def compute_validation_loss(
|
|||||||
|
|
||||||
# used to summary the stats over iterations
|
# used to summary the stats over iterations
|
||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
returned_sample = None
|
returned_sample = (None, None)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
@ -634,7 +634,7 @@ def compute_validation_loss(
|
|||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
global_step=params.batch_idx_train,
|
global_step=params.batch_idx_train,
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
return_sample=True,
|
return_sample=False,
|
||||||
)
|
)
|
||||||
assert loss_g.requires_grad is False
|
assert loss_g.requires_grad is False
|
||||||
for k, v in stats_g.items():
|
for k, v in stats_g.items():
|
||||||
@ -649,8 +649,6 @@ def compute_validation_loss(
|
|||||||
inner_model = model.module if isinstance(model, DDP) else model
|
inner_model = model.module if isinstance(model, DDP) else model
|
||||||
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
|
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
|
||||||
returned_sample = (audio_pred, audio)
|
returned_sample = (audio_pred, audio)
|
||||||
else:
|
|
||||||
returned_sample = (None, None)
|
|
||||||
|
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
tot_loss.reduce(device)
|
tot_loss.reduce(device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user