from local

This commit is contained in:
dohe0342 2023-02-17 10:47:03 +09:00
parent 988a165fe9
commit 8606c9a55d
2 changed files with 4 additions and 16 deletions

View File

@ -677,22 +677,10 @@ def compute_loss(
ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss
if not params.interctc and not params.condition and params.group_num > 0: if not params.interctc and not params.condition:
dense_fsa_vec = k2.DenseFsaVec( if type(nnet_output) == tuple:
nnet_output[0], nnet_output = nnet_output[0]
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if not params.interctc and not params.condition and params.group_num == 0:
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
nnet_output, nnet_output,
supervision_segments, supervision_segments,