mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
406a93be59
commit
ce8fb6fea0
Binary file not shown.
Binary file not shown.
@ -719,7 +719,9 @@ def compute_loss(
|
||||
|
||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
'''
|
||||
info = MetricsTracker()
|
||||
|
||||
if params.ctc_loss_scale > 0:
|
||||
# Compute ctc loss
|
||||
|
||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||
@ -750,10 +752,11 @@ def compute_loss(
|
||||
)
|
||||
assert ctc_loss.requires_grad == is_training
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
'''
|
||||
|
||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
@ -762,7 +765,6 @@ def compute_loss(
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
#info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user