From 326a55eebc5a769b9be9f4e5bead11756fe0e9a7 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Tue, 14 Feb 2023 13:00:21 +0900
Subject: [PATCH] conformer_ctc3: add an interctc option to compute_loss

When params.interctc is set, build the DenseFsaVec for the CTC loss from
the intermediate-layer output nnet_output[8] instead of the final
nnet_output.
---
 egs/tedlium2/ASR/conformer_ctc3/train.py | 40 ++++++++++++------
 1 file changed, 28 insertions(+), 12 deletions(-)

diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py
index ffbe41c7c..0557a7c20 100755
--- a/egs/tedlium2/ASR/conformer_ctc3/train.py
+++ b/egs/tedlium2/ASR/conformer_ctc3/train.py
@@ -543,20 +543,36 @@ def compute_loss(
     #token_ids = convert_texts_into_ids(texts, graph_compiler.sp)
     token_ids = graph_compiler.texts_to_ids(texts)
     decoding_graph = graph_compiler.compile(token_ids)
+
+    if params.interctc:
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output[8],
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
+    else:
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
 
     if params.att_rate > 0.0:
         with torch.set_grad_enabled(is_training):
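
Below is a minimal sketch of the loss selection this hunk introduces, written as a plain helper so the two branches read side by side. It assumes, as the diff does, that params.interctc is a boolean flag defined elsewhere in train.py, that nnet_output is indexable so that nnet_output[8] yields log-probabilities from an intermediate layer when the flag is set, and that the remaining params fields (subsampling_factor, beam_size, reduction, use_double_scores) exist as in the original file; the helper name compute_ctc_loss is hypothetical.

    import k2


    def compute_ctc_loss(params, nnet_output, supervision_segments, decoding_graph):
        """Sketch of the CTC loss selection added by this patch.

        When ``params.interctc`` is True, the loss is computed from the
        intermediate-layer log-probabilities ``nnet_output[8]``; otherwise it
        is computed from the final ``nnet_output``, as in the original code.
        """
        # Pick the log-probabilities to score: intermediate layer vs. final output.
        log_probs = nnet_output[8] if params.interctc else nnet_output

        dense_fsa_vec = k2.DenseFsaVec(
            log_probs,
            supervision_segments,
            allow_truncate=params.subsampling_factor - 1,
        )

        return k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=params.beam_size,
            reduction=params.reduction,
            use_double_scores=params.use_double_scores,
        )

Collapsing the branch into one helper keeps the two paths identical except for the chosen log-probabilities, which is exactly the difference the if/else in the diff expresses.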
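
The hunk reads params.interctc but does not define it, so the flag presumably has to be registered elsewhere in train.py. For illustration only, here is one way such a flag could be added to the recipe's argument parser; the default and help text are assumptions, not part of this patch.

    import argparse

    from icefall.utils import str2bool  # boolean-flag helper used by icefall recipes


    def add_interctc_argument(parser: argparse.ArgumentParser) -> None:
        # Hypothetical registration of the flag consumed as params.interctc above.
        parser.add_argument(
            "--interctc",
            type=str2bool,
            default=False,
            help="If true, compute the CTC loss from an intermediate-layer output.",
        )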