From 11d304beb0cd0cb9e635ad26cccf13e11f0d74ad Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Sun, 8 Jan 2023 17:48:59 +0900
Subject: [PATCH] from local

---
 .../ASR/incremental_transf/.train.py.swp     | Bin 57344 -> 0 bytes
 .../ASR/incremental_transf/identity_train.py | 1195 +++++++++++++++++
 2 files changed, 1195 insertions(+)
 delete mode 100644 egs/librispeech/ASR/incremental_transf/.train.py.swp
 create mode 100755 egs/librispeech/ASR/incremental_transf/identity_train.py

diff --git a/egs/librispeech/ASR/incremental_transf/.train.py.swp b/egs/librispeech/ASR/incremental_transf/.train.py.swp
deleted file mode 100644
index 41ac6eb6760eb7bef2a0e6f2d2cc8d458fbb8a0b..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
z!PyT2TAVyFarW606X_LXRUu$lw28rI_P3UzgWBjw%BRWmJq@s3;)ypE?dGK%h4g_`?_O^kt)ILx(Xu#~ z)K|*!Yx)rUvip!dp(1k@1=41!pp~<3YUmQ{c>kw#R2JrOd|KTB7sly{kiBGTPLxXr zy*!vvV&@O&2}wNZqt3iN%6Ldce6)!0Ko2e4XPOn7GeWlckyW7T(;>l`Y=ou#rJiF2 zr4H8mZRr>WqM=Sp=^I?{p^#JoDoGGNZZ2_4z8HG>c5L&t70%YMA@#4-saMr>Mn?z3 zq)Iy@u`7wBjExRN)`s^q1U+K`aKZ2Pj>Mbnf_*M-45h6D6TeZ2tY61=v|Il0a> nnO=YfD0&KWEs2APam`W{Kueh1bhFQ=Fa+aho#hp AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
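+
+    # load_checkpoint restores the model/optimizer/scheduler state dicts in
+    # place and returns the remaining checkpoint entries, from which the
+    # bookkeeping keys are copied back into `params` below.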
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
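+        In this script, train_one_epoch passes
+        warmup = params.batch_idx_train / params.model_warm_step.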
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If either all simple_loss entries or all pruned_loss entries
+            # are inf or nan, we stop the training process by raising an
+            # exception.
+            if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+        # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+        info["utterances"] = feature.size(0)
+        # averaged input duration in frames over utterances
+        info["utt_duration"] = feature_lens.sum().item()
+        # averaged padding proportion over utterances
+        info["utt_pad_proportion"] = (
+            ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+        )
+
+    # Note: We use reduction=sum while computing the loss.
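+    # The metrics recorded below are therefore sums over the batch; they are
+    # divided by info["frames"] downstream (see compute_validation_loss and
+    # train_one_epoch: loss_value = tot_loss["loss"] / tot_loss["frames"]).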
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
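+            # Standard AMP sequence: scale the loss before backward;
+            # scaler.step unscales the gradients and skips the optimizer
+            # step if any of them are inf/nan, and scaler.update adjusts the
+            # scale factor. The Eden scheduler is stepped once per batch.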
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + #logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + #) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + ''' + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + ''' + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main()
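
A minimal standalone sketch (not part of the patch) of the warm-up schedule
implemented by compute_loss and train_one_epoch above; model_warm_step and
batch_idx_train are the params fields the script uses for it:

    def warmup(batch_idx_train: int, model_warm_step: int = 3000) -> float:
        # warmup crosses 1.0 after model_warm_step batches and keeps growing
        return batch_idx_train / model_warm_step

    def pruned_loss_scale(w: float) -> float:
        # pruned_loss is off before warm-up, damped to 0.1 for roughly one
        # more model_warm_step worth of batches, then used at full weight
        return 0.0 if w < 1.0 else (0.1 if 1.0 < w < 2.0 else 1.0)

    assert pruned_loss_scale(warmup(1500)) == 0.0  # warm-up 0.5
    assert pruned_loss_scale(warmup(4500)) == 0.1  # warm-up 1.5
    assert pruned_loss_scale(warmup(9000)) == 1.0  # warm-up 3.0; delay_penalty
                                                   # also activates at >= 2.0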