diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 318e8b357..09fef4cf4 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index 3b5a3b728..9ba5ca6b2 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -583,18 +583,39 @@ def compute_loss( ) elif params.condition: - dense_fsa_vec1 = k2.DenseFsaVec( + dense_fsa_vec = k2.DenseFsaVec( nnet_output[0], supervision_segments, allow_truncate=params.subsampling_factor - 1, ) - dense_fsa_vec2 = k2.DenseFsaVec( - nnet_output[1][8], - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - + dense_fsa_vec2 = [ + k2.DenseFsaVec( + nnet_output[1][2], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][5], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][8], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][11], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][14], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + ] ctc_loss = (1-params.interctc_weight)* k2.ctc_loss( decoding_graph=decoding_graph, dense_fsa_vec=dense_fsa_vec1,