from local

This commit is contained in:
dohe0342 2023-02-14 13:00:21 +09:00
parent a8e1b34735
commit 326a55eebc
2 changed files with 28 additions and 12 deletions

View File

@ -543,20 +543,36 @@ def compute_loss(
#token_ids = convert_texts_into_ids(texts, graph_compiler.sp) #token_ids = convert_texts_into_ids(texts, graph_compiler.sp)
token_ids = graph_compiler.texts_to_ids(texts) token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids) decoding_graph = graph_compiler.compile(token_ids)
if params.interctc:
dense_fsa_vec1 = k2.DenseFsaVec(
nnet_output[8],
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
dense_fsa_vec = k2.DenseFsaVec( ctc_loss = k2.ctc_loss(
nnet_output, decoding_graph=decoding_graph,
supervision_segments, dense_fsa_vec=dense_fsa_vec,
allow_truncate=params.subsampling_factor - 1, output_beam=params.beam_size,
) reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
ctc_loss = k2.ctc_loss( else:
decoding_graph=decoding_graph, dense_fsa_vec = k2.DenseFsaVec(
dense_fsa_vec=dense_fsa_vec, nnet_output,
output_beam=params.beam_size, supervision_segments,
reduction=params.reduction, allow_truncate=params.subsampling_factor - 1,
use_double_scores=params.use_double_scores, )
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate > 0.0: if params.att_rate > 0.0:
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):