From 6109080f22d301e27744e7963b73fbb6e15dea6f Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Fri, 9 Jun 2023 14:58:05 +0900
Subject: [PATCH] from local

---
 .../{bitfit.py => train_bitfit.py}    |    0
 .../train_bitfit_tta.py               | 1567 +++++++++++++++++
 2 files changed, 1567 insertions(+)
 rename egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/{bitfit.py => train_bitfit.py} (100%)
 create mode 100755 egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py

zT(T+8C_0-=Wh2r|+=|M54EU&HTNY?UyMQ8G$K{3-G+A}>v}CIl&myo4lpZrjm)Vla zjl_1QkS>zcKj$))?`2iTbgUlzZI_O7h=IBok zCE0#0%04$KPEG>JCo^lv+SL`6#4 zA+);~NjCd7F;{ZlnUS!H!kTYacdWrYUT<7pPG=?AC1bjX!o1VAN@Xvh(krb`Uf#*a zzE=TOl48GL)_ertE6ly79{qc9pU9l9nt`Hs+PHh!ml9pAO#=s1fuvX)#491JSvTMh zV}cE=-M^aJ`hSzUinMz{3i^<2X%`H6RJ4}-`;g4+JJQV}Qucm9ZPRE?j;49s3L@$M zird!@xT9{_oNrX$orQdQ)HlGrM3ej3=T@x9--TW!T$5~%v}bB(;uoS=Hed+z-SMVT z7eaK-e#>|x&Fr2z!w6!3-HoE zsX%-LegV86h<(5x0XhFa0Xz~&yZ%_()_;u5RMc4Ha1g6yd@HvL>jr%1QcravnlI=k z8<*lvNk@! zeeyhF88K9Zy)9)J8vr>p^cCoPq5-g0KEcvW@q~gsuQiMu~6UCZplI54x`Jit`=Q|L>5ML zQ#$iEX6T_Zl}zB`rQQ9e|i#ol8U{u0g_AqvV5TrzE!4q00c zo1v--60_w)B@?ui1T72<536Jg+18G8y`(JFAV~qF5Wn1gr=%#Kbk)o=k?s}74o`B* zY>q$Ld*pBNG+*C)WFV5BkaI!9@2Nc3Vc^MlI!eNebjig%67`Rk7Xl?ZyRw)fuaH#A z(=Kx8JE$lAI(@+RBUdAFvr zF_>=`B&qao>F``e-AUwJj^Vd%p86fnxIlf5Y1}JiAm?L5C6L+c%qx(PdJRS;M(P8r zlH|&3Hj12sG2rT@O3nGe*n{_%wzuz0dNq@-Tu4u|YD`Ye#eyx7$=ip7^23R`a0?C+ ze7O;JRz3S0&M9suDm+dUWSY%49+uQB33K)y$)O>+^HBPl+&gCFj5bRtk~y7R$12vX zRS+)9=5^=J9d`V>-m9~}nc9sL&y@3z2$g2hdY`YzIop#>JN9^%5V|hqL^|o=⪙M z__JgqFLJHo0xe(=sU43y1L}+@Z&9s^HXL8eOb;hl@7D+=j$BHO#%zI|mlQumrz!E9 z#M+371RnAa>qbm}IJvohLOYAeo~X4|2UMk_DvdH{T2nNQz0PoGVByr=vMR*=3S-UK z5a!R6n;0e@b`OgBKnkVN++#A;{;`MaPUN~#YaShD?}LO-oogby2W6|^Bpl9xgQ-@A zSi6y$Gy7_uvZL_+m+diwi%qV0XEvS4Z6uBiYrLVu-EOtAp)g>)KA8H5^IhwuMmvG) zrW#a(glbppkaR(pV^ueDp+zqlLY{(F^)_&CKSb!zylakjIGeUOz=QR6K}7!~lm zqULLF50SN1H+h~TW~bE2xCzKT0dECYfoUN20AdRu?fX#B|6_Wp|5sPN`*nU}#&qBPn6i+_5hwSnH#L9G z!VolDjZ-6v&5gP{R7L>_ees?odQ;H7#X>1)@!eb{`aO2SF^XHO zpI2Xf+D4O)eKNHAwDo}ki+QRIWP(Yf=ug@6yFO-=bb3?h&BYZ(%eBfRxhmpgk`h{& zS1#!c$+-v+4T*r_8rr0jPKHBC;{XezUc^{zzmPTdKGOK}WoDS=FgFd}n+m0rEvCK= zQdAz=G2QZ`XdRhY9S*E4=*J8@Y(MZ+fC}t@;cFs3R<&1Dh0aDyP)uY={2D=3N%N5{ z=)8#)CP$;q-HzpCd&?E}#{-Q+{rZu)c)+%kZ=-&#V)`Pyc+2(X)!fBxJ*TXiKdutl zL|(TxRfg4Qa>dSTit9N-8uq|V+Pm8wdo}7j1vwplicD7nj)@0eMt%W#a#Tn zmXy+7l`SZHdQ96F!ZDN(q^79Rr|9Vei6y4GSfak@GtGu6v(P>#bP9S5UEqRw1k`~j zUSlOobzq7`^yINSj0s-5M;bp_Yur;JZ^eUx_A|Z7!@U?Zij{-W83WNOOR$>Nc|sk{ zKQSwX>DQV^I2Tb-6IAt3OQM+C9^12e*J^7Zpc9q^nygiuO#}gU(b!nDj{A}TwM3Q6 zUGG-24lk1+dvqJZ&JdTa9hX55X{fQTUX@LiG7=oge3gm98^cP<@?9?!=r)7V#&si{ z7Nz{q=$VwAUNrM&*85AK>b?y1TrG=~$4|oYOG%xgpQC`1(^gbM%4a@8nz!`3dFkjv zxYRXYv%6dh5hZ$dW?6_Z50l#ayEY-{P-8Z8^KG#A#gr89()|0ya85&Ja{s=2Y&pM;Ck>n@G{_h_C1ro zH-dZN$G-_)0EAC}E;tkX4Bq=k;BFvy`Mns3zrKe9xwG$TPytT?UxmJ}1w!Ymh1UH( z(zM~|$Y^vs$1mu!n~iSPoz-b<%Hvv~=AKC)aCMPMBYd-Wv19`G8JT3lH@(S7_2&Bw z1f`xYj6Pq}zv{UKBe|EU$f-If@iQcogCi=bggwpj8;0zux+c+A^^YGGT9fRhj%LB? zBAPpn@d|a*PD?d$9K ztcKy7eKj4oY_<`)2edq%ug)UEACp+6)<<}gk<9cJnV1$yiPn7a=s@m-Qy;2Mnwq+s z7D-hk-g4QjQH3g9n&%8#{%gXv84@*w3(nXeDX$;Plh7LMUM%}V!?R!=#pbmL9BkD-3ph14cv!EPk+C9Y|QR9qev7mVv?jMVh? Fe*h0J^o9Tc literal 0 HcmV?d00001 diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/bitfit.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit.py similarity index 100% rename from egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/bitfit.py rename to egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit.py diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py new file mode 100755 index 000000000..849a061c1 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py @@ -0,0 +1,1567 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --full-libri 1 \ + --max-duration 550 + +# For d2v-T training: +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless_d2v_v2/train.py \ + --wandb true \ + --input-strategy AudioSamples \ + --enable-spec-aug False \ + --multi-optim True \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --full-libri 0 \ + --exp-dir ./pruned_transducer_stateless_d2v_v2/$1 \ + --max-duration 250 \ + --freeze-finetune-updates 2000 \ + --use-fp16 1 \ + --peak-enc-lr 0.001 \ + --peak-dec-lr 0.05 \ + --accum-grads 1 \ + --encoder-type d2v \ + --additional-block True \ + --encoder-dim 768 \ + --decoder-dim 768 \ + --joiner-dim 768 \ + --prune-range 20 \ + --context-size 2 \ + --ctc-loss-scale 0.2 + +""" + + +import random +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer +from data2vec_encoder import FairSeqData2VecEncoder + +from icefall import diagnostics +from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import update_averaged_model +from checkpoint import ( + save_checkpoint as save_checkpoint_impl, + save_checkpoint_with_global_batch_idx, + load_checkpoint +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, + save_args, +) + +import wandb + +#from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], 
batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + model.encoder.num_updates = int(batch_count) + + +def add_adapter_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--add-adapter", + type=str2bool, + default=False, + help="add adapter to rep model's encoder" + ) + + parser.add_argument( + "--adapter-lr", + type=float, + default=0.0001, + help="adapter learning rate" + ) + + parser.add_argument( + "--gender", + type=str, + default='male', + help="select gender" + ) + + +def add_rep_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--wandb", + type=str2bool, + default=True, + help="Use wandb for MLOps", + ) + parser.add_argument( + "--hpo", + type=str2bool, + default=False, + help="Use small db for HPO", + ) + + parser.add_argument( + "--accum-grads", + type=int, + default=1, + help="accum-grad num.", + ) + + parser.add_argument( + "--multi-optim", + type=str2bool, + default=True, + help="use sperate optimizer (enc / dec)", + ) + + parser.add_argument( + "--peak-enc-lr", + type=float, + default=0.0001, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--peak-dec-lr", + type=float, + default=0.001, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--encoder-type", + type=str, + default='d2v', + help="Type of encoder (e.g. conformer, w2v, d2v...", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=768, + help="encoder embedding dimension", + ) + + parser.add_argument( + "--freeze-finetune-updates", + type=int, + default=0 + ) + + parser.add_argument( + "--additional-block", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--decode-interval", + type=int, + default=200, + help="decode interval", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=768, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=768, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--prompt", + type=str2bool, + default=False, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--num-updates", + type=int, + default=5000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=200, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=10, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_rep_arguments(parser) + add_adapter_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 5, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 320, # not passed in, this is fixed. + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "warm_step": 0, + #"warm_step": 4000, + #"warm_step": 3000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + if params.encoder_type == 'd2v': + encoder = FairSeqData2VecEncoder( + input_size=params.encoder_dim, + w2v_url='None', + output_size=params.encoder_dim, + freeze_finetune_updates=params.freeze_finetune_updates, + additional_block=params.additional_block, + ) + else: + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim if params.encoder_type == 'd2v' else int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim if params.encoder_type == 'd2v' else 
int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + prompt=params.prompt, + sid=params.spk_id, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + elif params.add_adapter: + filename = params.exp_dir / f"../d2v-base-T.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + strict=True if not params.add_adapter else False, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + params.batch_idx_train = 0 + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + decode: bool = False, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 2 or feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + + if feature.ndim == 2: + feature_lens = [] + for supervision in supervisions['cut']: + try: feature_lens.append(supervision.tracks[0].cut.recording.num_samples) + except: feature_lens.append(supervision.recording.num_samples) + feature_lens = torch.tensor(feature_lens) + + elif feature.ndim == 3: + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
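+        # Illustration with hypothetical values (not this recipe's defaults):
+        # with warm_step=2000 and s=0.5, simple_loss_scale ramps
+        # 1.0 -> 0.75 -> 0.5 at batches 0, 1000 and >= 2000, while
+        # pruned_loss_scale ramps 0.1 -> 0.55 -> 1.0 over the same steps.
+        # With warm_step=0 (the value set in get_params), both expressions
+        # below take their fully warmed-up values (s and 1.0) from the start.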
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + info = MetricsTracker() + + if params.ctc_loss_scale > 0: + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + assert loss.requires_grad == is_training + + if decode: + model.eval() + with torch.no_grad(): + try: + hypos = model.module.decode( + x=feature, + x_lens=feature_lens, + y=y, + sp=sp + ) + except: + hypos = model.decode( + x=feature, + x_lens=feature_lens, + y=y, + sp=sp + ) + + logging.info(f'ref: {batch["supervisions"]["text"][0]}') + logging.info(f'hyp: {" ".join(hypos[0])}') + model.train() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["utterances"] = feature.size(0) + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["utterances"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer or [torch.optim.Optimizer, torch.optim.Optimizer], + scheduler: LRSchedulerType or [LRSchedulerType, LRSchedulerType], + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, + wb = None, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + if params.multi_optim: + optimizer_enc, optimizer_dec = optimizer[0], optimizer[1] + scheduler_enc, scheduler_dec = scheduler[0], scheduler[1] + + for batch_idx, batch in enumerate(train_dl): + if params.batch_idx_train > params.num_updates: + break + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + if batch_idx % params.accum_grads == 0: params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + decode = True if batch_idx % params.decode_interval == 0 else False, + ) + + try: loss_info.reduce(loss.device) + except: pass + + numel = params.world_size / (params.accum_grads * loss_info["utterances"]) + loss *= numel ## normalize loss over utts(batch size) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
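+            # The loss was already rescaled above by
+            # numel = world_size / (accum_grads * num_utterances), so once DDP
+            # averages gradients across ranks and accum_grads backward passes
+            # are accumulated, the gradient used at the optimizer step below
+            # approximates that of the average per-utterance loss.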
+ scaler.scale(loss).backward() + + if params.multi_optim and (batch_idx+1) % params.accum_grads == 0: + set_batch_count(model, params.batch_idx_train) + scheduler_enc.step_batch(params.batch_idx_train) + scheduler_dec.step_batch(params.batch_idx_train) + scaler.step(optimizer_enc) + scaler.step(optimizer_dec) + scaler.update() + optimizer_enc.zero_grad() + optimizer_dec.zero_grad() + elif not params.multi_optim and (batch_idx+1) % params.accum_grads == 0: + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + ''' + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + ''' + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + ''' + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + ''' + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + ''' + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + ''' + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + wb.log({"valid/loss": 10000}) + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + #if params.batch_idx_train > 4000 and loss > 300 and params.wandb: + # wb.log({"valid/loss": 10000}) + # raise RuntimeError( + # f"divergence... 
exiting: loss={loss}" + # ) + + if batch_idx % (params.log_interval*params.accum_grads) == 0: + #for n, p in model.named_parameters(): + # if 'adapter' in n: + # print(p) + if params.multi_optim: + cur_enc_lr = scheduler_enc.get_last_lr()[0] + cur_dec_lr = scheduler_dec.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"enc_lr: {cur_enc_lr:.2e}, " + f"dec_lr: {cur_dec_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + else: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + if params.multi_optim: + tb_writer.add_scalar( + "train/enc_learning_rate", cur_enc_lr, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/dec_learning_rate", cur_dec_lr, params.batch_idx_train + ) + + else: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if wb is not None and rank == 0: + wb.log({"train/loss": loss_info["loss"]*numel}) + wb.log({"train/simple_loss": loss_info["simple_loss"]*numel}) + wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel}) + wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel}) + + ''' + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + if wb is not None and rank == 0: + numel = 1 / (params.accum_grads * valid_info["utterances"]) + #wb.log({"valid/loss": valid_info["loss"]*numel}) + wb.log({"valid/loss": numel*(valid_info["simple_loss"] + +valid_info["pruned_loss"] + +valid_info["ctc_loss"] + )}) + wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel}) + wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel}) + wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel}) + ''' + loss_value = tot_loss["loss"] / tot_loss["utterances"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args, wb=None): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + #model_avg = copy.deepcopy(model).to(torch.float64) + model_avg = None + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + adapter_names = [] + adapter_param = [] + for n, p in model.named_parameters(): + if 'q_proj.bias' in n or 'fc1.bias' in n: + adapter_names.append(n) + adapter_param.append(p) + else: + p.requires_grad = False + + ''' + if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n: + adapter_names.append(n) + adapter_param.append(p) + elif 'joiner' in n or 'simple' in n or 'ctc' in n: + p.requires_grad = True + else: + p.requires_grad = False + ''' + optimizer_adapter = ScaledAdam( + adapter_param, + lr=params.adapter_lr, + clipping_scale=5.0, + parameters_names=[adapter_names], + ) + + #for n, p in model.named_parameters(): + # p.requires_grad = False + + #prompt = torch.randn((100, 512), requires_grad=True) + #optimizer_adapter = ScaledAdam( + # [model.prompt], + # lr=params.adapter_lr, + # clipping_scale=5.0, + # parameters_names=['P'], + #) + + scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs) + optimizer, scheduler = optimizer_adapter, scheduler_adapter + + librispeech = LibriSpeechAsrDataModule(args) + + ''' + if params.hpo: + train_cuts = librispeech.train_clean_10_cuts(option=params.gender) + else: + train_cuts = librispeech.train_clean_100_cuts(option=params.gender) + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts(option=params.gender) + train_cuts += librispeech.train_other_500_cuts(option=params.gender) + ''' + + #train_cuts = librispeech.train_clean_10_cuts(option='male') + #train_cuts = librispeech.test_clean_user(option='big') + train_cuts = librispeech.vox_cuts(option=params.spk_id) + + def remove_short_and_long_utt(c: Cut): + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + #train_dl = librispeech.test_dataloaders( + # train_cuts + #) + + ''' + print('\n'*5) + print('-'*30) + for batch in train_dl: + 
print(batch) + print('-'*30) + print('\n'*5) + exit() + ''' + + valid_cuts = librispeech.dev_clean_cuts(option=params.gender) + valid_cuts += librispeech.dev_other_cuts(option=params.gender) + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"update num : {params.batch_idx_train}") + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + wb=wb, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + ''' + if epoch % 10 == 0: + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + ''' + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + if args.wandb: args.exp_dir = args.exp_dir + str(random.randint(0,400)) + args.exp_dir = Path(args.exp_dir) + + logging.info("save arguments to config.yaml...") + save_args(args) + + if args.wandb: wb = wandb.init(project="d2v-adapter", entity="dohe0342", config=vars(args)) + else: wb = None + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run if not args.add_adapter else run_adapter, + args=(world_size, args, wb), + nprocs=world_size, + join=True + ) + else: + if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb) + else: run(rank=0, world_size=1, args=args, wb=wb) + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main()