from local

This commit is contained in:
dohe0342 2022-12-10 13:06:29 +09:00
parent 63e2c16dc6
commit c3da411351
3 changed files with 5 additions and 1 deletions

View File

@ -735,14 +735,17 @@ def compute_loss(
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
token_ids=token_ids, token_ids=token_ids,
) )
logging.info('1')
# Works with a BPE model # Works with a BPE model
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
logging.info('2')
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
ctc_output, ctc_output,
supervision_segments, supervision_segments,
allow_truncate=params.subsampling_factor - 1, allow_truncate=params.subsampling_factor - 1,
) )
logging.info('3')
ctc_loss = k2.ctc_loss( ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -751,6 +754,7 @@ def compute_loss(
reduction="sum", reduction="sum",
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
) )
logging.info('4')
assert ctc_loss.requires_grad == is_training assert ctc_loss.requires_grad == is_training
loss += params.ctc_loss_scale * ctc_loss loss += params.ctc_loss_scale * ctc_loss

Binary file not shown.