From df013c37fadb5381b715a1ec39e1496a99adf7be Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Wed, 15 Feb 2023 13:05:09 +0900 Subject: [PATCH] from local --- egs/tedlium2/ASR/conformer_ctc3/.train.py.swp | Bin 61440 -> 65536 bytes egs/tedlium2/ASR/conformer_ctc3/train.py | 54 ++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index ab8cb7b4817f3d96fc038b48744c7552b3f31ee4..194e6637a81c89c2d56cd1df0acffdaf5915ca50 100644 GIT binary patch delta 574 zcmYk&Ur1AN6u|LwyH^)%bN9wtC~RG<=p|&7R0>H&#)e9QVwpr+S>l89vf67~&_h8` zc;FT?DKen~k#(PvK<%mMC8~!We8`tD7)a1-f4(9?2R^^w;dl7`euq=}YE~*%;!=7j z+vj$)rG>CXj&|2C=R0eOddz52)1EhSTC;aKvCT5kvbD+M;FY~U_R_Vy=wXgN8Z6w- zsdN8R`N(#WIMzcVZ}9}97(y2NU}3dUq=o{{qXQ1UHHggO9ww1N3aggLJ3K=fgK+T6 zq=o>uO*tjqBnPn^h50>P6Vp5O!h^6@ZnXq6y-8DZPZOHk=u&UySg^agHc-xwiaZL_b7X6r|EB$!@z0@% z9yH^7lgM-UxPw;2u^tt9k2#EC80XNA27KY|7r23wXvZP&XGmH}c0>Zhx*OJ|n-M+W R+d+^2)6kJj1_UD37h!I(ddiANa$az%ll*hbpG=&FTr~SVtAB z_wH3aaql%?&7T;mj*C?g#Z`NShGu!8`TNR`M7u9W?_ j9C7hiL))Z2X1A7%E%uGQOY*ML?~?5zC)@Yz1s diff --git a/egs/tedlium2/ASR/conformer_ctc3/train.py b/egs/tedlium2/ASR/conformer_ctc3/train.py index b8e9c53d1..8630d5dc5 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/train.py +++ b/egs/tedlium2/ASR/conformer_ctc3/train.py @@ -636,7 +636,61 @@ def compute_loss( ) ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss + + elif params.group_num > 0: + dense_fsa_vec = k2.DenseFsaVec( + nnet_output[0], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + dense_fsa_vec_inter = [ + #k2.DenseFsaVec( + # nnet_output[1][2], + # supervision_segments, + # allow_truncate=params.subsampling_factor - 1, + #), + #k2.DenseFsaVec( + # nnet_output[1][5], + # supervision_segments, + # allow_truncate=params.subsampling_factor - 1, + #), + k2.DenseFsaVec( + nnet_output[1][8], + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ), + #k2.DenseFsaVec( + # nnet_output[1][11], + # supervision_segments, + # allow_truncate=params.subsampling_factor - 1, + #), + #k2.DenseFsaVec( + # nnet_output[1][14], + # supervision_segments, + # allow_truncate=params.subsampling_factor - 1, + #) + ] + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + inter_ctc_loss = 0 + for fsa in dense_fsa_vec_inter: + inter_ctc_loss += k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=fsa, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + ctc_loss = (1-params.interctc_weight) * ctc_loss + params.interctc_weight * inter_ctc_loss else: dense_fsa_vec = k2.DenseFsaVec( nnet_output,