diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 7280dd93b..38d121c97 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index 0e7e74d2b..819d55aa2 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -653,6 +653,14 @@ def compute_loss( for i in [2,5,8,11,14] ] + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + inter_ctc_loss = 0 for fsa_vec_inter in dense_fsa_vec_inter: inter_ctc_loss += k2.ctc_loss( @@ -663,21 +671,8 @@ def compute_loss( use_double_scores=params.use_double_scores, ) - - ctc_loss = (1-params.interctc_weight) * k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) + params.interctc_weight * k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec_inter, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - + ctc_loss = (1-params.interctc_weight) * ctc_loss + + params.interctc_weight * inter_ctc_weight else: dense_fsa_vec = k2.DenseFsaVec( nnet_output,