From f6b4c61d992f09c1a359d135aa85ba3e2bd06a26 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Tue, 14 Feb 2023 18:02:48 +0900 Subject: [PATCH] from local --- egs/tedlium2/ASR/conformer_ctc3/.train.py.swp | Bin 61440 -> 61440 bytes egs/tedlium2/ASR/conformer_ctc3/train.py | 26 +++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 09fef4cf4bd26b9870d5f619361e6dcd2441eea6..9b802dc05a0de53df1c71dbd1df4e68fc9357ca2 100644 GIT binary patch delta 462 zcmZ9{JxD@P6ae63YWlREin*x6-69lu2}VT*wUmM6)Kt9cr4WdQN<%*+Gzt-%p`jsB z2gl+KEk#oW4e^HPN2A^z7_Jr4@!1ew__&;V&$;K^ysFQude^G35Q(y}fB>}Y02k2@ z>FCT=x9dW3)9$?LFPH$f&7I2-IuvPdzYjnp0_ho% z1urDWx$rPBhyw9*ZU%;(K->$&fj}$`#BaD581@121RzcUVhtdc0%A!Z<_6;XoD2+y zfOrlN#{qEw5c>nMHW2>-+H#&_Gb77i_Q?e|1SVftBE5Op4F$%{pBH5@P1ah%vw6au z5@yy|J_d$`lQ|!#^NH~@Fnj{?PJ=}FHZK&6=bPxjGI_zf#K{6?ER*@~$xN1g&&y~u I+2H*J07#`gSO5S3 diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index 9ba5ca6b2..f6c1e4fd9 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -589,8 +589,8 @@ def compute_loss( allow_truncate=params.subsampling_factor - 1, ) - dense_fsa_vec2 = [ - k2.DenseFsaVec( + dense_fsa_vec_inter = [ + k2.DenseFsaVec( nnet_output[1][2], supervision_segments, allow_truncate=params.subsampling_factor - 1, @@ -616,19 +616,25 @@ def compute_loss( allow_truncate=params.subsampling_factor - 1, ) ] - ctc_loss = (1-params.interctc_weight)* k2.ctc_loss( + + ctc_loss = k2.ctc_loss( decoding_graph=decoding_graph, dense_fsa_vec=dense_fsa_vec1, output_beam=params.beam_size, reduction=params.reduction, use_double_scores=params.use_double_scores, - ) + params.interctc_weight * k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec2, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) + ) + + inter_ctc_loss = 0 + for fsa in dense_fsa_vec_inter: + inter_ctc_loss += k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=fsa, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + else: dense_fsa_vec = k2.DenseFsaVec(