diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 09fef4cf4..9b802dc05 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index 9ba5ca6b2..f6c1e4fd9 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -589,8 +589,8 @@ def compute_loss( allow_truncate=params.subsampling_factor - 1, ) - dense_fsa_vec2 = [ - k2.DenseFsaVec( + dense_fsa_vec_inter = [ + k2.DenseFsaVec( nnet_output[1][2], supervision_segments, allow_truncate=params.subsampling_factor - 1, @@ -616,19 +616,25 @@ def compute_loss( allow_truncate=params.subsampling_factor - 1, ) ] - ctc_loss = (1-params.interctc_weight)* k2.ctc_loss( + + ctc_loss = k2.ctc_loss( decoding_graph=decoding_graph, dense_fsa_vec=dense_fsa_vec1, output_beam=params.beam_size, reduction=params.reduction, use_double_scores=params.use_double_scores, - ) + params.interctc_weight * k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec2, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) + ) + + inter_ctc_loss = 0 + for fsa in dense_fsa_vec_inter: + inter_ctc_loss += k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=fsa, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + else: dense_fsa_vec = k2.DenseFsaVec(