From 0ca6f2b2cf9674211369c3ab3890a5ce889b54e2 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Tue, 14 Feb 2023 17:59:00 +0900 Subject: [PATCH] from local --- egs/tedlium2/ASR/conformer_ctc3/.train.py.swp | Bin 53248 -> 61440 bytes egs/tedlium2/ASR/conformer_ctc3/train.py | 26 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index 647319415a4a2eb8ec56f4743a085b32905be520..318e8b357d5af4851458575b905cc0491dcc8e6f 100644 GIT binary patch delta 457 zcmX}oze@sP7zgm@^`sxDcV}P?4&HufA`x_35R`?qCDag?7D%O%(G+?^t3kYLgGmkz z4faMsM1ey~P5BQr6a?WQ?$P&6&j&ud_uTv3`#kSCHO2W*LR;~*WXhO}#R;JkO{ab` zy=JJRu>e=}YnJ9seJY{KUBq~Qp%qX6ferG&U=0uV#e+Vg-56dcuz{kstjFnT5v*-8 zQ>HI)2D=c22u#2i%3E+QKVu}W8YVLhL~h3@+6VT2NhUJ3_lB=xnL<9BJIodGmgv^C zXvNMF9r=iUK{j>=YiUHcAfN;bFbUtd;UhS(3}NWtR?om84M|u63smU(iC&=!6)3|x zn0`{wSTOIgj~5eOe!4iu-!3EK{lZ|PSHJgIa=bEHJb@g~{jY6|B1fyjD$K(SjKb%C Py|9BPs(Y%qYs|7glB!qj delta 140 zcmZp8z}&EaSv1KY%+puFQqO<^2m}}yw0vJDAKxhYmYB}R1;jUV2>#%k zc)($k0!srgJ5b?rMu_Uf&4LT!8TCFgF)%y^;_W~@3y4dA*bIm@fLI-fnSuBUPzy-g Z1;)vZ?xK^=ERo-=c}Ib9^QT38x&W?QBB1~P diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index b4a60e82e..3b5a3b728 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -582,6 +582,32 @@ def compute_loss( use_double_scores=params.use_double_scores, ) + elif params.condition: + dense_fsa_vec1 = k2.DenseFsaVec( + nnet_output[0], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + dense_fsa_vec2 = k2.DenseFsaVec( + nnet_output[1][8], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = (1-params.interctc_weight)* k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec1, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + params.interctc_weight * k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec2, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) else: dense_fsa_vec = k2.DenseFsaVec(