diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp
index 38d6a2df6..e29f93389 100644
Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ
diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py
index ffbe41c7c..0557a7c20 100755
--- a/egs/tedlium2/ASR/conformer_ctc3/train.py
+++ b/egs/tedlium2/ASR/conformer_ctc3/train.py
@@ -543,20 +543,36 @@ def compute_loss(
     #token_ids = convert_texts_into_ids(texts, graph_compiler.sp)
     token_ids = graph_compiler.texts_to_ids(texts)
     decoding_graph = graph_compiler.compile(token_ids)
+
+    if params.interctc:
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output[8],
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
+    else:
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
 
     if params.att_rate > 0.0:
         with torch.set_grad_enabled(is_training):
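
For reference, both branches of the patch duplicate the same k2.DenseFsaVec + k2.ctc_loss pattern, differing only in which network output is fed in. Below is a minimal standalone sketch of that shared pattern; it is not part of the patch. The helper name compute_k2_ctc_loss, the explicit keyword defaults, and the toy shapes in the demo are illustrative assumptions, while the k2 calls themselves mirror the ones in the diff.

    # A hypothetical helper wrapping the CTC-loss pattern used twice above.
    import k2
    import torch


    def compute_k2_ctc_loss(
        log_probs: torch.Tensor,             # (N, T, C) log-posteriors from one network output
        supervision_segments: torch.Tensor,  # (num_seqs, 3) int32: [seq_idx, start_frame, duration]
        decoding_graph: k2.Fsa,              # FsaVec, e.g. from the graph compiler or k2.ctc_graph
        subsampling_factor: int = 4,
        beam_size: float = 10.0,
        reduction: str = "sum",
        use_double_scores: bool = True,
    ) -> torch.Tensor:
        """Build a DenseFsaVec from the posteriors and compute the k2 CTC loss."""
        dense_fsa_vec = k2.DenseFsaVec(
            log_probs,
            supervision_segments,
            allow_truncate=subsampling_factor - 1,
        )
        return k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=beam_size,
            reduction=reduction,
            use_double_scores=use_double_scores,
        )


    if __name__ == "__main__":
        # Toy usage with random posteriors; in train.py these come from the model,
        # the batch supervisions, and graph_compiler.compile(token_ids).
        N, T, C = 2, 20, 10
        log_probs = torch.randn(N, T, C).log_softmax(dim=-1)
        supervision_segments = torch.tensor(
            [[0, 0, T], [1, 0, T]], dtype=torch.int32
        )
        token_ids = [[1, 2, 3], [4, 5]]
        decoding_graph = k2.ctc_graph(token_ids, modified=False)
        loss = compute_k2_ctc_loss(log_probs, supervision_segments, decoding_graph)
        print(loss)

With such a helper, the interctc branch would only choose which output to pass (e.g. the intermediate layer's posteriors versus the final nnet_output), rather than repeating the whole loss computation.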