From 17b81ce7d6453307d8b2bd6fe5f727b194185c10 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Tue, 14 Feb 2023 00:43:42 +0900 Subject: [PATCH] from local --- .../ASR/conformer_ctc2/.asr_datamodule.py.swp | Bin 4096 -> 49152 bytes .../ASR/conformer_ctc2/asr_datamodule.py | 400 ++++++++++++++++++ 2 files changed, 400 insertions(+) diff --git a/egs/tedlium2/ASR/conformer_ctc2/.asr_datamodule.py.swp b/egs/tedlium2/ASR/conformer_ctc2/.asr_datamodule.py.swp index 13ca5c0389d65def46329c35cf0fd625d0263d44..bf679c4f7b8bc4607c225b731f8a4ff9aa3ae22f 100644 GIT binary patch literal 49152 zcmeI53veVydB+zp<`Im+P(Wcw+IzrBpk3*7e&BP-Sk~Re7w*OBjuHtB#aDl$`M?$DpYgbJ zE0tET=uT}te_9@$=-SSE?!c?uj#fO z-=B0l(PZqld~eX5o}8UOFxd=x?I7&BVWSy0rzf2#Y_y!%=?1NV@7DV(#mvj^B?U?f ztWAMI+^)U!tjgBuEqbVLs-CBwan7N&-!Gq(6euZBQlO+jNr93AB?U?floTi_@Tj9e z+<$WAHMHunNz3}l=VJ<=k59e_$>(*2=VvD0=abKS3(udBe7`pN+?o_%%XwmQe?0lT zz3}{uRp2Pt4bBD+U0%C9#cp3N)3>yCiJ`U~#w}3_PBJey=0rz2a_%!%aumoNR&ILckSo1I7Bj8TZ z0Gq+{!DE2#OR`WSu4=8;8+2<6gXWSON7aepOPw$n^y^Nm)o{YjpzHSHvEj4v;a`$t zHS)Ti?}Vx~2%XpqdTK0kj=7QQliKTbR6BIKZZuxc3tNt7=8k+pC4g$iRW;R@X+cS6VOO&)W6uZ25L1RVx(RBXKMK*?jQ zDqcO4?FFZ|R2|RPR<}W3%HP(F=y`3juTlK}qY-&tq{Gz37fvmz-QGgz&G%imxmZ{G zV9{{di`)q%FKn6Og=nx4x$#=f+9+SP2fknPrOcCRO4S*3WVq^~7)>ZQZWe1+b=z%v z$vc*?t%|(sU0&7WKt?QM#);jICpVpb-}m4~>6Ze(RrjwP48rE3+Rn4xfzxus$PCuq zi$M%^<_7V+8*h$OttLZzh2h+b!@$?$yf})ZX6W@re#}%4GX7l^=x>p7LRVN}{JBh4WB@SXQ_T8B(K4YfJq3hF@md&k5Q#FH}3L`hn$-0IX07l}$o+>K`qp_wB zT}-NN)UKodtj{KLa=*@Qd|O3C&}B{EiK4C>F9xkt0!Oaju9GmLbotP)N;!>&*Yn~= zV@%$F_B_j!uvg@$xXF>ZPExHJA|71yBBe`G(PA+0TWY~o^aed_s226zrq}kGbU@e- zLdVb(5h7%YooFeln;dH@=+)?|TH9YyZP$qh5Z^rxKfo_RPnTL#Sr#*cu_hy1y0;0j zdz$y`gQ1)?`~{ELC#{v9)1^{c;%g=}Pu$~u>Ml$y{9EtT4J~)+0%9ZzJi**RHD!M2 zGUI6G?0bFJ_j;}}PJvOjk*0M{ucbr`)dJFn<8|Gf1u8CI$qc~_;kwKw9utCoRcGxd zygeqB$;tzhThn(PLmH|`cQai^NtX1B<1afa5p*R(5r;}@p+#aQZ_{uYM|-qKqvrLc zY27~=a+n%Rx0~KR=(c8~Fii$u=r$>?ws*V{QdZDyM6TP?ELE4k)Jv2&T;fXUwaip* zSPvIXf5>{nH?}d;#FpVCU(LKapBeskTNSxEf_decO>Ao0nCI4EU34Th)$@Ad@mU|l zo*yZ%3wy>Qfb^uDyu*Zo?A347byp{^NyClYF?sn~?Bug5kfCB)ysvj~(cS5<>@+N% zu;lEZkr~aDj#iy$CTQndck#$+Ijxe7PS*emCYB83E0Ea4wBpf1;|IkL{> zN-2B1dDoGdXuD5Kk(wEsOc{1BunNZukCME3CGX);Tx*dI%2+W`MJy|iB2kk+RwMme& znI8@sO^4N{EzJ))9g47Dku>yJY`-FowWx=6Hz}W0Z!*cMK@VZ8<8Df4$SU7xNFC} z+IK+B?c2M3*TG%;_VR6~n%#Sqdik!s+Yy#5x~K^fBXlEVPAEi>>02t)^59z!afb+4 zRbz#u%D#Y5MH(`Vs$KI{b@}Z4uK5Y_IMnkcE-x?F9cHyfw;qI@N#9T~n%uo>ZpYsF z9kpqk(Its2t-H>O3KnDvVDZ}w`h08Wi0s6?xstru3u3)Ui)DIjdsNUp-gNt{$$K(! zKp3aTBHgT1nM^XBm;9L})U2+feQ83yl9`N-Qq)h4L3UM55>>}9(G?a3sMjZlKS_f$ zBOnk%nUIab6uKQpuk_Qo#~H9>2qT@*91RV<&SI>|rlzN+rUJ!@s;~!*|6fFf6(f_}Q9()_RvgrTge8}|` zbp1Dg*MZl9K9~bF@L%Zre*@kRZUL8p3&9281>nc%`+pDK0j>u7!9H*)5PQIV==@>_ zh{3hM16P7c@I&-`u><@j5F5aaAOJhT)4_jI&))}9-&>{5EAM)a+BgHVS%vy>oqPfDB#i zT)n(fsdD76fCS3__f*Dt}OB2Ir zVj>)Bp<}GJPfbI$Me5;csLpz1L-mu>*U*-G%-Hcxc{_(W7n{k0)7_m?mxre}hdGaG zJk~ClnlkMkjyM^4g@KJyFZF55 zZd^7)c6}z)HN%uS*`V)=IaDWTh)q@UIDO-y(8`SwQ}q9H*t>o%``n`cTmSl>qUYZO zz5xCNTnAnVo(9eapGD9A3vfGlJ-8OU5=?_F;Opr5_kw%Cd%>H*o51V94)9#?P4xXc zz+rGU_$K=Oec%(|UEpRg0Ox^cgCC*qe+7I3ybW9f4uA{6bHHQ4zhldIJ2(Pf0WJXN zfbY@1uYk9JqhKCf1a!L}6*Feph>`*&1xgB(6euZhDkvb!BYTxzW2M(v=`}XegHfEq zO0TifYfLo2(rc{r8p9*<=gw;^Cx78$ciI0F&U^Kz#k*2pmua z&jHT{KRN@u1NbQTBhUgfU@Mpg&jSC0PyfB(&%q5K0wFjIc7Ymr5WoH}gU^AFfOmtN zfDbMLF9FX0(vM%WhawD@fp#62n#}N0PmSru9ly~XL{87xs3ie*rCuawFpAgek)sux ziFHD?-*K$Ne-AfM{E3?J(CtZ^rNqZcdl}RwqzjV6wpH!LLf#K{ z7m0<4v{zh=E3?A$t-s0-e>Wb5xa;XWBApQ5WgL;^W#Ms^Z)g?w;>3HVl%7> zA!^dLnyL565q%WP$f`Tj;GCow@?NI6Q@}R1Rhe8j!a2>r*IX`+gGt0wkv&X9MVu0% zwE&-puVW(qGOpu^Gq>iyxWF;Cucq^3e9dSei3KyB-wzcFRuK})RUG!>NKbcY_fO&h zkRk&Hv$rZdrrP#g9|v*W_#>O60lu%tJnj0r`RI|&$kfe8Q`bxpu`yQu?8n>D}CXg%>b^~WA<~hk(LPcy8&S)a-bbMmbDz?w$)6VOo!yeOuYLfHG z@o?2l&e=F1nJ#49uOeO=6ogQaQBzksUde`s8sBW$V%19qle3SAu zmez0G4|*M(M1(FfX*-Ds$P^|$YMQ>?jU_ywm(2Y{Zt7^!Y?k5_Z$q#QwNF9gZ2{|> zwWYpt73KxnPt}Q;xZcR(Ref0Gn~?)0h63WVk5Vn1G(MCdxn_^s$ zzO6?3XTAcuH|T}6!wK^pu}N*0C8}O7i2P#3M%6(KlAY* zTtGu34%O&>(ae#l`o$A!OMS~xf~jYwC%29tB{DUwGntu?%8JcaKFL~F0Mdzd$=Ft# zHs+bO6Yt6#lLWTfl$$d1A~9PNoF<{?+kS{FHnG06CQ~*CFNj1A+^DxXR-!>vY+mw7 zbf2_E;`O9qCW@;c1`8}7S+|=QY}SO>FT79|_AYX|n8nuPM%&lRXacbF1!%3s&J-=B zU(zws5_7k8LM=<+8cR`$&tG8)8Z0Pis8@_E!wQ+rwnWt*26m&n7Ak)USxYQjK%AM^ z!)a|rCZ}bgYay%9?Gmo2r!F~`LnbX*0Qbe_*IC5}d-omOu}!l|mIqQwTUx1fJVKd`L z`C#Ie@v+Q^-pCNm%gLc@vaB&M!kV0twFP1_*PZ#K+Y^HoylumjGji2M(MS!Z!mN3M zkG>;!%>gik`$kI*Icj1?{3k2&q5EQ~ryN0(Y%6&M8C4yhJT>8v(F=x7W(Ql4Vk{OWGcLZ%%85!Je3MbNx*-!$>uGA&zyG>a-2j_ z#_rVQ#GfOzEgkWrKhBC$qO9iyeVo{zsxWg7EA>aBaiNq*8B8AgTj`WdPtXN>a zc7@(QXwAEc&6oa97l~`KWRWLwnO1k`W0$gq(+mscq;9e!WXnmpp~%RPpRJ)B{qwVW z?vYS#wGffJcICNx_!XW^D0K78=;T{PGRR(Yxzi5v?30=Esd^=wEW=xnqR48=0{m?nbpLw&SEl< z+dhVU$}sPG<|=5-c^KM`FhgUQN@Ty2wJ1xenW_2(oLJBc zXU4YR>OVb&<&s4XcBNx3r>5w$#_}W{_DM8SmZt?ULUM;+WYz=LW@v+$bcHk#QYZd= zsd;J)r+$e_+9R#e_O=XNR#PyM0hWNG|D!TjKAq_Q*Cx8^chLDC0OI>EX8^Rp7H}r` zKKj1c0X_`g1g-&c2H^9-=h6Fr5BT6R@G_vlS>PM!{GS1D1INK0Fb6INF9ze_YuExl z3+@DO06sVuh(CbD{(lR|*?-~-&;WAQ-@jrLcsIBm+zhS(vj6`R8ulO%Tfv>+7&ro6 z0mScL&h+~#_$YV>xCykt)!-^H3oZj^fX`whxC0yjW8k~k2tENm1Z3a;HDE8;3dX^M z*bF`o{ubN;ejOYJd%%U@0`LMLd;cE-Zv|I?OTi4706)Q&@OR+tU_aOic7XH2nLziM zwJq=|oh{c1m_#oQO%zoc5o1`4dC#mJlZYn*W=?ku%P@K;W&et95x>mxK0lN;5V9xR zXeO>oM+eF9OomC?ZW9AV@X*hu-7JzH)^?MBRy-8j&EpNXJLwc1_So`=eU+D4vWiOf z7&G&B=5Br_ON>S{$!Gqd-Y_#dMcySt*hH$4zMfPH@w*+Bb+0SeInWY%#WrkU%m`EEdq(bTaO_ADM(uh`u$JuWWZ?CzXHTNCr(!&& zM3?5v_3)73gi+Oa&y$qnP20Vh{MS4WZCoa%mXM63%xAhTIcXX(QcmheX_mIR1tB@) zsC@f)D|_}O6G?W1LcZ~aIpk7yFgS2?46TzD7}tav8y~NGvD@W%2-~5XL>oQk1%n7f z{et*7z|qzk(3Ruy&|=vp&+IB}Qk$iq&HA7}do-+A@Wp`89E4&YQ)d6fxbL#Xu;}4Q zHT0oM=P;>8j#m7Gcg}N6-FG_+LR zNk#PkN9y^18P&J^v!p;tfsz6x1xgB(6euZBQs5Cz0a@S5m0z2u>+#~&S$Y1S(>wLM z2N*=nrinI!7EPt6`OG*uzfPa!cFNm1%(>W19+c<*Rb(7qc8U+J)cbMe`G4c7jpgJe zou=jRl*5gUj&lG1k-h)_De>D#E*YnCJ`J|F5LXk!*LD`~N9P%Kd--EkMc;mM~d=&ULwY zNl|0Y0-@)xgOVPdB2SRFXR7(54bkr6J5stuiMADpkS>( zu3{rYG(YbT?Ypa-9Y;7l=abrK%1pR+#{Y7ZyP#o;^mpF$Wbc>p|2XtkEH&iQ@CBWm z9Mv*1W;EA1W&D2`|F6%5E#v<~ecVoRGC&#sU&jBt%7(978M~gjYOaj`=RcW_i2pCo z|9^Cy|Nm`kZ?KL5nNK@la!PY@+KWDZB8{v|S1YKF<@x^^+h1zFv*B{7u|`}J&9^kb zJ`E&~#@2H9F~^^l=l_@I|JxzT|5PZ?|A!bO4+bvJ{~vigJ~BuDH;}?9dgb~5@R6lW zdH#Pr`;W>}Lq*=0Q=b1Hb8>8X{=c4%%JcvG>l-xX`Tu75QsunM^8Ejw|MUO<58P1J A+W-In delta 16 XcmZo@U~W*@@RE7r1=h{)n1y%%Iz DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments.