From e3ef194d43935544001c8e0ccccf894cd8bd3191 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Wed, 15 Feb 2023 15:13:10 +0900 Subject: [PATCH] from local --- egs/tedlium2/ASR/conformer_ctc3/.train.py.swp | Bin 69632 -> 69632 bytes egs/tedlium2/ASR/conformer_ctc3/train.py | 25 +++++++----------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 7280dd93b403607803f7852ec5fb374ec9273fb5..38d121c97578d8468bf4a252d421441bfe672503 100644 GIT binary patch delta 434 zcmXxgJxD@P6bJBoQz6Nx?-}MGgj<_AL}6J8(bu36VHg#PBIc4%f=h%YK@DPcC=CZK zY-q3&NrDKPs?n{bt=bCv7+Q+@zrq7QxaZ!(x%b?7UK!6T4Swt7sMBn-SQ$cO>@@Zq z7%mxJbX-dcvB*7C5{+t|e2Nh1P5dn1O7>;J#u+R)GzehIaT+ zGWG~}a1N_50cNNv7`ucdIH3o+;Tvt`)>Sxw2smhMYkCP6B{m16Ok8@w=a?jOGNjIg?nzl4O6CRCd452V)ZdcZO zi=nE_WpXQd9;c$B4q~GfolZV*{SUY^mz^x=NoiPbqI=iZmN-)R%10Yx zH;{&T=z|)LJU|Mzz%E&Z5vb#eGTeay8!!L@+~9&2yrckols=v~z3$*qOZyFQ+PU!4 z?d&AQDhUebf<~=}5;=(ut4Vsj3XuLOk#6*poYlyjl4v|T)XHD`5z;d4lscCwm?_)W zSkaTx^4yt^v2U>35l*)awEYhEU_uhcAOgeSg=c(O9#XIXDtN(TE!GP@iv0+>eDveZ NCGH73VKtjv`2*fXVP*gT diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index 0e7e74d2b..819d55aa2 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -653,6 +653,14 @@ def compute_loss( for i in [2,5,8,11,14] ] + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + inter_ctc_loss = 0 for fsa_vec_inter in dense_fsa_vec_inter: inter_ctc_loss += k2.ctc_loss( @@ -663,21 +671,8 @@ def compute_loss( use_double_scores=params.use_double_scores, ) - - ctc_loss = (1-params.interctc_weight) * k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) + params.interctc_weight * k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec_inter, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - + ctc_loss = (1-params.interctc_weight) * ctc_loss + + params.interctc_weight * inter_ctc_weight else: dense_fsa_vec = k2.DenseFsaVec( nnet_output,