diff --git a/egs/tedlium2/ASR/conformer_ctc3/.decode.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.decode.py.swp index 03ea403b1..2f9ba22fe 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.decode.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.decode.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index ea41fd560..547e2af73 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index 991b5c0e7..e41ffeebd 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -688,62 +688,6 @@ def compute_loss( use_double_scores=params.use_double_scores, ) - ''' - dense_fsa_vec = k2.DenseFsaVec( - nnet_output[0], - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - dense_fsa_vec_inter = [ - #k2.DenseFsaVec( - # nnet_output[1][2], - # supervision_segments, - # allow_truncate=params.subsampling_factor - 1, - #), - #k2.DenseFsaVec( - # nnet_output[1][5], - # supervision_segments, - # allow_truncate=params.subsampling_factor - 1, - #), - k2.DenseFsaVec( - nnet_output[1][8], - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ), - #k2.DenseFsaVec( - # nnet_output[1][11], - # supervision_segments, - # allow_truncate=params.subsampling_factor - 1, - #), - #k2.DenseFsaVec( - # nnet_output[1][14], - # supervision_segments, - # allow_truncate=params.subsampling_factor - 1, - #) - ] - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - inter_ctc_loss = 0 - for fsa in dense_fsa_vec_inter: - inter_ctc_loss += k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=fsa, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss - ''' - if params.att_rate > 0.0: with torch.set_grad_enabled(is_training): mmodel = model.module if hasattr(model, "module") else model