diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 647319415..318e8b357 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index b4a60e82e..3b5a3b728 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -582,6 +582,32 @@ def compute_loss( use_double_scores=params.use_double_scores, ) + elif params.condition: + dense_fsa_vec1 = k2.DenseFsaVec( + nnet_output[0], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + dense_fsa_vec2 = k2.DenseFsaVec( + nnet_output[1][8], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = (1-params.interctc_weight)* k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec1, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + params.interctc_weight * k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec2, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) else: dense_fsa_vec = k2.DenseFsaVec(