diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 54f9872a7..37ea317bb 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp index 53fadf4d3..a052c8f62 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index dc4b0348d..74490021a 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -638,6 +638,7 @@ def compute_loss( ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss else: + ''' dense_fsa_vec = k2.DenseFsaVec( nnet_output, supervision_segments, @@ -651,6 +652,60 @@ def compute_loss( reduction=params.reduction, use_double_scores=params.use_double_scores, ) + ''' + dense_fsa_vec = k2.DenseFsaVec( + nnet_output[0], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + dense_fsa_vec_inter = [ + k2.DenseFsaVec( + nnet_output[1][2], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][5], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][8], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][11], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + k2.DenseFsaVec( + nnet_output[1][14], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + ] + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + inter_ctc_loss = 0 + for fsa in dense_fsa_vec_inter: + inter_ctc_loss += k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=fsa, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss if params.att_rate > 0.0: with torch.set_grad_enabled(is_training):