From ab16ce6424105056a8a138aa75297b7136cc3021 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Thu, 16 Feb 2023 17:40:25 +0900 Subject: [PATCH] from local --- .../ASR/conformer_ctc2/.train.py.swp | Bin 57344 -> 61440 bytes egs/librispeech/ASR/conformer_ctc2/train.py | 52 +++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc2/.train.py.swp b/egs/librispeech/ASR/conformer_ctc2/.train.py.swp index b7ed26d19b408f2d24a24ca0d2ec800f6f3a7a91..3d9b44bb3aa5b18f0f14fde69da420e0760141db 100644 GIT binary patch delta 1905 zcmbW2&u<$=6vxLuk~E}Aokn%3KhCcEQ!#d_A- znRVkH*!}^%nG>jp1GK0T65tCbR3L2;Jyi(71?T|@2`EPp;=*^=jp%P>mbDnQ=Z_aU5%i|rZ$)#HtN?S^E?JDis%4ntZVJP*r zhkA66ekY}Jf`{7mWsP2Hh(M65h*KcA@+E1Xl<8NR#6t5d}%EO|-f=C(J8O4ea zE%K$;p-^iPa2eS`(K%N)$O!{%1Gd+MJ|~2vDlFN)bclna=HN>uwjEJNC{4dDEF4jw zEy!4&FV`*DE!#CZ5$Z~f4dc)Occo>ATG0}D?plB|6vA#8#lZ?fmsZs@6Pa?QdwX-{ z?4C;f1@3up9@Ymci7&TKEgtZuGU_3g6HkLMQeJ7P`^Wr|sWK1NRVC3Yl}f!|znsc; z|Gr+?PZ*d;7J84}`0+@@9mc*GW9;g_e(i46_NNCJYk@Ps08fGua0{0p0#z^$ex79P z8}Jd>0IT46@EjNgxArskHFyKegK_Zf1m?j96U=5;@X`TKgJWO;jDqhTVC)@m3A_n3 zI0Z_8z@Ouc{Qy1#mp}kcgNML0D1e`^sjK^XnfzB{y}v%sWhYW6)y-@gXCB{~GDoy6 zU!&yG2jAwN>+F``SlUjq8)v^<>DNRhKI884=JF0_Z*M-8p&nA7DG%$#sJk92wLGXc zSN1)j%q&~mD3t3n(O$y$m*&8I>^%vA#cx|T~e71<#&w**T*!f+d{F? zh5I5-q)gL1HWW++jcf{8?*7w$II;~DwnWf&6)wIyYrbY^v7sydI}rymo8E0-QMC*q zZQl{t^lsWu5O!(lPD?Dga;L3fLpv>qD;4IOf^=7|79;ewg-4TMQ=R3WT<_9W28UKf z3tU_x+ZgsRe9gy=F)@jKU?ZqlfDAQ!&q>X5NnzA)hPK%5g6Q%(aMzL!Om#n?jvof( YaE%tLxBG9((Es|s=)|Vhb3(I!0nP0Wpa1{> delta 253 zcmXBOy98qi_7~5ATeW> z*yNJUWsyft{p$4eoH~tz+L)>CVkDQ!h-fH67*1rq+)DlI*nM?b?Nmpm7&`Yz!8v0U zL)SdAthJr3wU`J~*AYhH$^MrWsm=VX?z0fR%|tKQ!UmcU0}eYAau~r7nvjMhBtXDL zB|5_q=Fou}R3Q!lu)$d&nnDlCP*BK4c`QNr{qAjD?-$V4 BHzoi8 diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index aab2d2acb..146d10017 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -1066,9 +1066,9 @@ def run(rank, world_size, args): 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) - + + ''' tedlium = TedLiumAsrDataModule(args) - train_cuts = tedlium.train_cuts() if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -1084,6 +1084,54 @@ def run(rank, world_size, args): valid_cuts = tedlium.dev_cuts() valid_dl = tedlium.valid_dataloaders(valid_cuts) + ''' + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 30.0 + + def remove_invalid_utt_ctc(c: Cut): + # Caution: We assume the subsampling factor is 4! + # num_tokens = len(sp.encode(c.supervisions[0].text, out_type=int)) + num_tokens = len(graph_compiler.texts_to_ids(c.supervisions[0].text)) + min_output_input_ratio = 0.0005 + max_output_input_ratio = 0.1 + return ( + min_output_input_ratio + < num_tokens / float(c.features.num_frames) + < max_output_input_ratio + ) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.filter(remove_invalid_utt_ctc) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) if ( params.start_epoch <= 1