from local

This commit is contained in:
dohe0342 2022-12-10 11:08:49 +09:00
parent 406a93be59
commit ce8fb6fea0
3 changed files with 33 additions and 31 deletions

Binary file not shown.

View File

@ -719,41 +719,44 @@ def compute_loss(
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
''' info = MetricsTracker()
# Compute ctc loss
if params.ctc_loss_scale > 0:
# Compute ctc loss
# NOTE: We need `encode_supervisions` to sort sequences with # NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by # different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss` # `k2.intersect_dense` called in `k2.ctc_loss`
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
supervision_segments, token_ids = encode_supervisions( supervision_segments, token_ids = encode_supervisions(
supervisions, supervisions,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
token_ids=token_ids, token_ids=token_ids,
)
# Works with a BPE model
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
) )
# Works with a BPE model ctc_loss = k2.ctc_loss(
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) decoding_graph=decoding_graph,
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec=dense_fsa_vec,
ctc_output, output_beam=params.beam_size,
supervision_segments, reduction="sum",
allow_truncate=params.subsampling_factor - 1, use_double_scores=params.use_double_scores,
) )
assert ctc_loss.requires_grad == is_training
ctc_loss = k2.ctc_loss( loss += params.ctc_loss_scale * ctc_loss
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec, info["ctc_loss"] = ctc_loss.detach().cpu().item()
output_beam=params.beam_size,
reduction="sum",
use_double_scores=params.use_double_scores,
)
assert ctc_loss.requires_grad == is_training
loss += params.ctc_loss_scale * ctc_loss
'''
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
@ -762,7 +765,6 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
#info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info return loss, info