from local

This commit is contained in:
dohe0342 2022-12-10 11:08:49 +09:00
parent 406a93be59
commit ce8fb6fea0
3 changed files with 33 additions and 31 deletions

Binary file not shown.

View File

@ -719,7 +719,9 @@ def compute_loss(
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
'''
info = MetricsTracker()
if params.ctc_loss_scale > 0:
# Compute ctc loss
# NOTE: We need `encode_supervisions` to sort sequences with
@ -750,10 +752,11 @@ def compute_loss(
)
assert ctc_loss.requires_grad == is_training
loss += params.ctc_loss_scale * ctc_loss
'''
info["ctc_loss"] = ctc_loss.detach().cpu().item()
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
@ -762,7 +765,6 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
#info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info