from local

This commit is contained in:
dohe0342 2022-12-10 13:08:48 +09:00
parent c3da411351
commit a7c41ca5dd
2 changed files with 0 additions and 4 deletions

View File

@ -736,16 +736,13 @@ def compute_loss(
token_ids=token_ids, token_ids=token_ids,
) )
logging.info('1')
# Works with a BPE model # Works with a BPE model
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
logging.info('2')
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
ctc_output, ctc_output,
supervision_segments, supervision_segments,
allow_truncate=params.subsampling_factor - 1, allow_truncate=params.subsampling_factor - 1,
) )
logging.info('3')
ctc_loss = k2.ctc_loss( ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -754,7 +751,6 @@ def compute_loss(
reduction="sum", reduction="sum",
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
) )
logging.info('4')
assert ctc_loss.requires_grad == is_training assert ctc_loss.requires_grad == is_training
loss += params.ctc_loss_scale * ctc_loss loss += params.ctc_loss_scale * ctc_loss