from local

This commit is contained in:
dohe0342 2023-02-17 10:16:13 +09:00
parent a47bd6a4a5
commit 2d998e537b
3 changed files with 2 additions and 2 deletions

View File

@ -559,7 +559,7 @@ def compute_loss(
token_ids = graph_compiler.texts_to_ids(texts) token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids) decoding_graph = graph_compiler.compile(token_ids)
if params.interctc: if params.interctc and params.group_num == 0:
dense_fsa_vec1 = k2.DenseFsaVec( dense_fsa_vec1 = k2.DenseFsaVec(
nnet_output[0], nnet_output[0],
supervision_segments, supervision_segments,
@ -641,7 +641,7 @@ def compute_loss(
ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss
if (params.condition and params.group_num > 0) or (params.group_num > 0): if (params.condition and params.group_num > 0) or (params.interctc and params.group_num > 0) or (params.group_num > 0):
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
nnet_output[0], nnet_output[0],
supervision_segments, supervision_segments,