from local

This commit is contained in:
dohe0342 2023-02-14 18:00:42 +09:00
parent 0ca6f2b2cf
commit 1bd60745ee
2 changed files with 28 additions and 7 deletions

View File

@ -583,18 +583,39 @@ def compute_loss(
) )
elif params.condition: elif params.condition:
dense_fsa_vec1 = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
nnet_output[0], nnet_output[0],
supervision_segments, supervision_segments,
allow_truncate=params.subsampling_factor - 1, allow_truncate=params.subsampling_factor - 1,
) )
dense_fsa_vec2 = k2.DenseFsaVec( dense_fsa_vec2 = [
nnet_output[1][8], k2.DenseFsaVec(
supervision_segments, nnet_output[1][2],
allow_truncate=params.subsampling_factor - 1, supervision_segments,
) allow_truncate=params.subsampling_factor - 1,
),
k2.DenseFsaVec(
nnet_output[1][5],
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
),
k2.DenseFsaVec(
nnet_output[1][8],
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
),
k2.DenseFsaVec(
nnet_output[1][11],
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
),
k2.DenseFsaVec(
nnet_output[1][14],
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
]
ctc_loss = (1-params.interctc_weight)* k2.ctc_loss( ctc_loss = (1-params.interctc_weight)* k2.ctc_loss(
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec1, dense_fsa_vec=dense_fsa_vec1,