From 6c82d8469442eebac3d723c20bad652be0ba3db0 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Tue, 20 Dec 2022 16:16:46 +0900 Subject: [PATCH] from local --- .../.decode.py.swp | Bin 4096 -> 49152 bytes .../decode.py | 151 +++++++++--------- 2 files changed, 77 insertions(+), 74 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.decode.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.decode.py.swp index 492e84d93e4e8f7d16005a8d75a7da73ee0f36aa..0adc33faec4d7255f9e959034e7e30a6a9969ebf 100644 GIT binary patch literal 49152 zcmeI53!G$Cedo(yFQ4VHl@pKxl^Q>Z)7a*Qu9v ztERi@v>MT6MP#E<#1J2tk1hm^1Rr385PU!q7FUSIEbJx<@fl-49*c@(zyI^NkE+|R znZZvqbwB;P|L6Jtpa1!v@xk*qo$H*M8%*&Wo=ROfa$Rxdj;9a!-pf;|Qhm}L zTyxTDd9~290e^1W%br^(o>{N&-qx(=D)nNa(i!G0Ho}3adULnmC=}fRujrNwmCArS z;SW^2@uufDT(>wmux@1gK%?2NxutxoS*ZD?cF}F-{Z^snR$SlDmsan|?^!)ia*Oqn zn`_K;rq%o1qd<=W5el?h0_Oj9cVr*UaZ%b-X{+xKEhRoXM~a|J%mhetw|i{qxQH7n#p%I^Ms)xNkO}*LS=> zz_{nk=f8ElH_p9}9tC<7=ux0YfgS~V6zEZ)M}Zy%dKBnUphtlo1^z=)pioMsgtvaF z@U!gyTmJuthow^A2cHBJzy+@YM}r4noJxHVRKadA1Wo}zg}MGSPz0|8&%7{|dKx?i z9t7_P1uzVb0nY=ElK8{mTJUx-4$cBua3uI8f`RXWd%%~#UxGWp?cgTxVenp11v|j; z;5hI*;8}zYkAa84{{Zg+zX!&^I`As+?+7k_4E_#$0o)2M2BTmI91EU6`0sZwM!}h26*vaG5PTJ3#aqEHZ~-_U41-sK7l22p>ra4-fYkdD zAo4ioN!RzdO{X>KI&Pz0oODpbm~fqPvtD(^luNd-XJX9BxYNZ-yX2PooL1eu8Ox

X`cr9nD-ma7!t;FF>d96v6T!C~86K=_=qT5k%!l)+Q zN@F;E!K6#3$XVT~)C(mi?mg$MZ#SE6tyP(k6zVlT{dS{KZ?^nMR3+3-+LhL@%1V&e z)=a}4_FK)>0-U+yTx|P+v7ULNw>JWQTJ-(c-$@5o30{D=2w8F5OVCgh%vQhUvp?<$@sgCD6-b9VO(8Q=UuN5R|2{agzBauK- z%gdn9z7r)KrSaB!%N=qys=iySx1}|y&bun1{H>r#`JHXeREpGr>Zmc|TS}~1j1;Q# zFM}2&M+=31`C_Hu`@{4t)zfn8Oefo*JFtHaW5QY1v#KH8Bb{NVE6g_4S2HPc?M%3> zy!M&t8)Ed^a^BhbjoU^DM!J(IuuWuEx!F~NeW{f0$Wu+KGw(K=b*+cy>&lOh7JiHN zqBYivO|Q|ip4?Wuktfaxx2c1zYqgr*c)R6p@QU;*RWU{vF2^WPq3u;=%yKgEcNKTq zW2~@mx9FcHi!vPM!}hLIDtoo4qsfsO3nIS#ybQp3JILkLAg7~9@Z6E*Fj1Q`20>LQ zwD-IqyE00uc!ROh?{~zp={BNHrg92ZUd#803ZM#Z%FA~(3g*Y%LN)KZg=TS5=Xj#& zx}_Q8lpim&5?>es75r8-hJ3x|^4m9qUAgmhP2T0}%|fx_&J%<7vgPcBUh8+UoGXGw zTtZr^HwC|$FBMt^irucb_T?7QnjBGDlk5EE2Hu~m-lt3hqD2-bbG4dN@SR#M=4yZV zOj>TO=r%mJ=!!qTQMK=OqyDKYa;72kxQSIYTW5<0m7Mb$Ew5fHRQjDA?MB7b-W&8H z6gsW{ZT!?PiuH=jH-reLuqm^Jz07VN2~MlQNzAq`|=;Ak!E*ZaTT8sKcIm^5A6NpK|FJ2nC1%Ib-QBI0Xu(^zRv53iSy^>t3+)L_)ZndGYqVQg z#zAy*CuUe&l)U;#p-M~em7gW$Tvq~$yaZ9xy(SYwN31K$GN7x=JDpwhG^bQgWd{4y z!hssf?g*x?)J~tJ<&bJQE1Zf~^-w8RZLwXWveo3bg@KsWST41~8FBU6%Q&eBWu0Lb zU`40GqPR_r&V;O%Ytlog{;KEuDpD2>#B#x=9p{|4bBA-my6xN7ZQZeHy`6GzGE%CRh2 zq8?~Utg+JTMx0F}Y3Iy!Bb!G0iDPF?RH&V%`{uSMs~gskg$5}UrIsntl%!cubeAO= zbs%#eMeBF`dbu@KXu7Et#3RpStJN497?_%x$`w%Po^*5d<^-DDih};YmQCw7ZXMZ} zU5$dd1d%09wJ_t<$6GQj6r5t6Drc!GM72aNwD37&eAenMv$~^jWseNPZ@7(?WUE*g z!W0;}Lgv9lBD6AJql8+g@@U`mJLkJisVIm1IT^`A+5}4Xou)D@4Iq~r2>wigR6}58 z(xfpKb*F8*6NP3;8bB4s!lU^+g*JV?=?BD86kFoIi`?48lQ#?dO;!pz6B(BFBsoMEtxdeCU2`_K`9h_qRvbIf_L~FaUTwgw z?Qt42_^Yi+3IG3d_~5U>FAM*lGmPL*;qz|=H-LA8wIB-)15d*1KLKt9mx39v9;^gU z!S_E2?gH-tSAy4rgTR46cM1$q?dQJ_bG9tC<7=ux0Yf&VNNkhOIc#v;>~^2$y=UxNwH z=bhnUC!NpBi+nyU!anh5qzXB&R<38#8;~L#n?~3tujC1(_wox)j72@iP7zTIM}KH3 zj(Cx$C6XZ@ft4gYY=6c5Aaq2oHLO3vS3W{E#9%5JvwFR-p)ole{~5+5 z0XbsJ(OKc>hQqIMFMea#Z^WHouieV9`X7HWPp&1UQbm`v6xlI}k_eHi=I)j73hHE1 z5=^-KfI|9PsUlq6rCuXrUm*WeUrQAc(n>J%lSoI^SA=kdTAsA^o4WAM{M;3KCPiPo zr^SC(`A_#d`j_&a-W7@3MnH(jdE}f<7U7!6SYaA<4z*{tZG2r%u^cTa2AzE!Lzv|- z-lRyvx3qIeqDvmeloTd>OBfX>QaS?6Cq>L4Z8L?FZ@DIt>}D-5877qMo~g)II8&}u zL+0x6#_>w*BGg3`(Wt=eH#wv|QB(3I+mUi64_dOP`(Qa8?p&!B=%_~LvI*c1IkUBP zRRs2h>KxsaLOX#28S&fI%xEKrlUvDXxlcpdl8Pr6#*^vW73t$)#~_NWR*oof$9q=v zt&2QyC?$UE)8%*f>KrnfJzk-r`)NjZp4=V|#gFHG?^4&vI;;9gWDjx@da-=OJNo=N zTP*#YUt|lycTO1AwSvmu9{{)YJBPfGy;8oyh6aErDF9IvUA>a$}^VfpSK=%9(1;XEd5qu0>2wn{i2Dibt z{~356cpW$fyafC!vWa`awcsLf4tN=Of_?cLz$7>yoCpp8KV-lC!{B^y4mc6WKD_L| z|1qe7*Mn2Q3E*h(6ZY2s3S0{=1V@0U*+>60m;xTy0{XzC?4REOt^&Kkg1bXIvZ5wt~D{D-B0~ zEaDwo8O>H0K>h3vlCZUxv=10>KDz8`w^3Q-NN8bQ1$z1kTQ1);>{k)9T5k&5u?t zm>W=0X9H*1Y@h4u2SGHdj0;Ajsn8VzTi`C6!8-gI!!moQ??A3DPs>`(IDJ>_&)V7H6psMoxO3erp+wTC(zZwXzkJ$OxDPiXB+YV4wP zqBjG=O?iKB8@!o;{Lw+Zpw`8qK<0&5CH3Fa6RdKHpB)N4!6|D*7*#0AVpA!un zX^z&B?+LV!Xjw6`JLJ~t?TN{V>W}YaW|=uNL7w{NWJ$<&nJ^u3Qb;stD4=z8excHK zH;Q^ke6VvX%oy?LoYN!S#5vJf6<-iaoX=s2VL_?h%4(m0 z-l{#swNC78ZnUyQ`_DAf;j9^g#HDXEWn&9xd3hm&lg?cCXp}r1CCqGLi|AZ@$E$0> z$tyNnl+!vHwR#H5-??_OMG8EySK6S9XXc+!JT;5NvnA0$70;0gAYN0#Onb73QIlLo z^~zu@*V7-(6du}UMLvV6pMLX>%{jW^}u+-m+fy)N%{M?JElEy%ZcHwbQd@ME=#@m;KkSDB4{X6`>&oh8T zB@|fdm?-`iFXmFlMu{ojXT(LG=opz&E|GuHKG&(4xBM|uBA#!YH19-z!gL;VYdh5Y zh7|w58oquU2>%c3tDbN1`vGtpr~r9?Ja{>H9G?Gva2*%}TY&?fh3EeaXo4}21&4z> z;rp)xZvgATOTbg`{C9y*f;{L4hXK*=|17u$oCQRm|K;Fuc>eprbwKp@w}PX=&*Aw$ z4QfF2_E&*}!N0-ti=O_)U@KS!UJSkozyERIgKgjpFbLA%2?~EVxDMpOQQ%p4{?CCq zPzT$=0C)->|I6V0;LRWp-T=0N)4&O!4?F?ie+#$_TmV*pN8$hP0`D+#0I8|y(~|-) zA@!sH(Iyeh)jcUd0x!^$0z}%%qD!|_#52miO0W=FUiG8^nm);k${6t$(VEnCkdj1|BN<6L zKxae^djJ1D@S)#?FBSfOv*G{02%rB6Py^?Jr{VAa9=shifSmdNUGOmc{g=TsSO*+% zFZ}%-;KSf-kOqf?AHdws>fi*zR_kRogH?SULft=fa0DKmdz$?Ik;Ct}(cY?Qr z&EP2T7(D$$K+fjh3f>7uzzN`=;psmFWY5mcsCs6qf5vWL|CeJC!uHt_Qd(h&| z4&6%>u*8BWmU~fE3kj02DYdw|g#=szHyV12Xr+i_^hu`E-l|_n%=<<7A`!(l)>7gU zeWLXvv+G`}Vn~!hR8~|ne`Lx>w?c=+vLuxZ@n}Zm{;w$kyGoc|H?fqH&ty97CjpKmu)kC}B(CLk9Lt_}s&q(9-901lout+wL6Rw@BHLG{ zMn|$=R6HcAg*gH7v3Eiq3Hd9a!crL?$XuVwwQ*QXbooL`nR) z-{wpn0h6c9gc*3LY9UfE9q9rSY^*abPdy|?xV;w+Nilcg`@tPR?lrg=NPQm(k{{_2^6d0E(W6Z0O-;HrrC)7T z?j56YFIG7cEhL4y=GMyEh@3ysrbdXC145{6V}uTCD*DKZN<(aVU;qnkZl$-+GPyPB zhZf7M!Lh;IYr+JS@wGAoN8CQ0iXf{_y)4J_SYm1N7;{o);F1e#%t6KMukByrg=eD_ zMbz;*O(9z&-9%bFxR_Zs4osnb35AzIppJ57vo-_yut~&>_Qa@S8^gJUX|LL@$|){m ztq{{m<8=(5VLu0LY0kZL7id*Sop)L!?6SbLP_d+ehoh*@*A7k)U|y}|PNQru2enYC zlv4$&h?sfAhEkRDhS+8F9c>Hu=BY~`&Su|OxQY{wvnVw|!B zTA1ds9BRZm+@?<3j%q;E3zd==(odT-bTVGf z&GkEDoUN^w%PhZDACXBDIg$$7*Y8YCdKl?tI+pRuXZRFDetZv#XDef!qev}NtTHm| zaHy?~nKP7V%~>jAZ&s(KCT9oRbP_-k3nH{#dNs8f7G|EwSVAsibYnf-HYcB)@YjdQChJqsxi%wqYD*q=7dceRA!lW7 zb~hIGR1b17c^1TH>9N=dC)ruG?n5}q!QrS|L(LyY%jj7|9V_{R; zR8}U{+MIMrZs%%+BUaZYVUeXZXd8DkLGc_K+q^-S;!B&%T!PO9y~JK7HItZ!NY%&O zV+)kvjcpOz>e%lSV|c1-_*jySocXrZoQ}m2nfO%?z)WA*Coq<57fGgBE~6{xOv%bb zDxFjHO|?X!mZe9^aau>&c~HyKEoWme5!(=v>;(B0+o1#MIPG}D&A(=FFycw&c^fOi zSf*_&TXt?GCTT-57M;0mqh_xo$<3dq0rHkp+<3%$B1Q4IY)jlBEjftfHa#)dEfkHU z6=ZFJ>*v?zi>=}SQAENe<`FN_ON}=~{)^iR*S6Kiq|GV&=CDd~x+geilyf49rq>l) zZDqw1rzYiwxXQGSF{(sdlgu?&-Ik*{o5KIBF z1+WQ-F2Fy4d%;)1SAf_AxDtpxK(P&QG}5NZ{$awBaq@(HLsj`k(nMY zy9F$QV54AK`LMs$56!H^#e5+HBdPiDfJgyc&XG2m6&cVvEb@z|id$~w%`({7&B;aG z!XZn17bn}b-O9>_^{bgTOkLX-@c;=K@9|~=P%VXr zOyd+$GtRKFMjj?8C2JH)!x@ff>p!!97$rns+`p}wlmwQ?IFLDO6gpRK*OV>hLIo=+ zkxUt{y4Ft&4L(GwA|+8GXf_$?u1aG%LA=GH6LT#LzI!}nBX=~-_>@P~t zJgQc6X7PA1j$GU=r>JwR;TMZlhWkh+gfEEv^hZ|`$I4xywq3NC)a~dlCnYqeYB~wG zS_fLu5h51*g4GfANm9zN`e-`pYKlgw7Eg!en^H4W_xHi@*C zX!WZ?n5}U8+O~kzVxqXyag7QQ#%Xe=>{HXU)LXsHP1b9Ocrto0Qw>;|8Orqa<+Kmw znMJqy#KIC2t2TYabSy@|?817xvWKMJ+ZCR%ag#F8p$21U z?N%3@u}%X>V^Uk^hy+#mWRp1?i?)#z7F<$Li&`xY)N5N?qR@Cu7@&)wio9d0_9G0? zSY@)_!m78df-wn}+q_yE&EFncrxN4&yc*M0=V9WiR4=M4U6etEfU1a844*T%u4Bk1+gua--vstxmbd)m>I;xqnc{yZBuKd=JaDZz>xNkjOasS z;SCA_3G_IvqgaP6uP`#5Q~tN*#tmgp+sHj9xDkcQmHQKHW&_WTc2{!JoHM2-u2F9n zatxg_TfwG31P$NIV%p!vlGU^ua-gZDbJft{YqcF-phB>eMcVSFxt1M~z6Ajd>hV;{F|Cp)ejm_ZXEe9RuA%ZluVKm}@Gw z$u$+*n)M=9?CXZh*4={(E^_Cp+b?uT)8w6w+g4!lj6`4G^&!nW!@^mJYsf`NST*g` zbwP<}PxDsKVF`Dd1n*Uh;c`K9pR!^LHW?5u$=K<;bTcDkv(!cG$7~@RAqXJKVgti( zkEUy)h!A?q26~)yd(EqWI>{jT{xTn^o!U zlc-EAGSb?E<#&8g4Yx8Pmx|eiY$OmBVatl0$w|fu{tDdrz)c=gSU8M~L$)&1#UmZ1 zBp2N`i`nGl1`W*uEu(7YmB%f0vr_i|kEH3}VEF$_4F7m9eEzjS_WmybVk6)H@Ll-% zTfuw4W#B?^5_kfh{VU)`@HQ|C`oW9Ax8c`653T~^;Emu^Z~zcK-tOb?)!vG$gGqGVYrqrc1w zdK(s#nR@tRMXO(!gV@^i;tagaL^!Nj7x|P)*7Dt2_R8!Uc7v$b2oF6R`JVTZ zy{i#E$uzCRi6BGR5fV3Y{*b+v*HImy4r(alz>$E)N}zqT;TB1PU$#gDB^;sglU^rYg)0(KB!g;TD!hFDXiEIH&elKH#`;J?9&ZJ z+>{uBwP9SF&_-R>ERe)2vr5E{ky;8US&lXTE1p8vc*;51LFd|^lz8fHv?A=XJU-uZ z7^n4GaD2+U1d+t0Au-zf4c%+p3=a8LeK!Jyi1nl}jD%Z4oA0DWtj0yM7H=aF?@^y| zPl+8O%r3OSN7p8rZWoJnh@_Am4TF@P8*)YpuSJTT4skL>-tG&nVtqj07`HT1#R8@k zV{XRXu7NauXn%e2 z-Z{_OO-p!|*7g>R;#bvQBhAXzZ^S8HDZM6R*J7###fh`GO=e%($^wmKH00cwy2GiN zAqykwGV#e1nKZ+>CloD(9>sH{UG$oY?69STD)#wRj4oTw-j~*bj>rNHhWg)si>h#n zuq`EYiIXilK}7xj=a@8*v#HB!b+swpi=00)A9fX}}gh;IL-paD(=r+~F!1$Y*I{}FH>xB*N9xf|f6 z;2C)SUx3HJ_rdi*bo?to>;rBDr-9?Z5#Vv;0RI=f7s!19a)+VZ74S(Qa)D#O|AGJi z3b-BI2>ujY2QC9uFaov%2fP40gIwTY@Evd$xEf4?auGtMz1-u>WUAw z7>wG!?Ya#+IBCTlI@}4O0?5TJvh*}p#)K1wae$9O%8P&_)8!-dE4enSZhV<53Cd}e zil3cT_p%^kSNHPFWoHq>^WOyU+vT!1J(p{-#7)`ahxsFY-M3_>BVfM~qay>v+-~Gv zpd4F2RbOrsmq>ynr?|d^I7G_M)dA*hG~7_9yj|U*9Ff0^=-r6oUYbZ(T-Eqs1%Bhifn=C0ODS$j3jX*Kp&aC0TTVkr4t!Om%wJaVP6t~QK6{m*k zBmz=(Ws7)_Hjz>hntwqI5Qf9o`lnZ#o*PhM)NhwyK_JGpz>{;lKn@#qEgsH+b zbl_E%m?~FjEy$LUYF!=ao3f4y;dbR{eqxMS+mg^4qsXkLWV^09-BqNJ;_pu zyNYB5q`pR^3KCkLz#C<@440iQQ_6?=C+*C}9Ym;;k+Mj$lPi+67oDiOCW%fI5m88F zX7gOAKtz6@kSaowXrf{H$X?`hpJYbgK_W3o+VZI`d8P?ocVinvws?l*LLt89JiS%NSD}k4XQwX2Xqva_h}=uHMk1KU zx4br*3PO41i$*rxufIwI#?=I4k<9&u}gi{hU zE+*^n^BV1E;YCGU{?XA9RRP$);$D{5>fUT_Z{ES>>B`uC$SO$34gE_yrGP4vqGF>v zc+k1&^XgToqQPz1QfvkxO%A(XEi+fyYqAZ*kfe7(%Wf0tKy*UW+a>bWr~277)qYhH z4%@Z9+2U8*K0;V^XSLjR@8_J2D1kdI`4QQ|o~lDzSM*DRjdlX-?PY4#YoS~|BU^uV z$6EAtM53X$*P{E;wj|rNu8H2e6}|5|%G}IIK}_+gkO;Ml+4WYX(UAmsf`RKPny8)V z#T8AYCG8*G4QXC_+NHU2{rb|BHXNR9Gdt=iFrfS)dMt8Qntfs6bc{NS?yD`>||6^S ziD)j2TztzTot))NIh~JR<`$$BcaAKaBsAjyw{tCc$CY@qPmowdS`#KTuCBP0le`_N z1i|t$NGrK;7|OcL#Q;I{p;IJJu>e*U5DQ`*qnVOWLTFWW(R#F= z)XvHBE)cVXwKd(L929LDdpXrg3et4!+BCL()n`|k@zt4x(f0rPY0m#__Wvgg^C-Ii z*MPTx0U);je+Hlb4R95B8yE!_fkVL|K=%J{0b=`K?g2a*90q;>fBzY93zz^W1L64} zhL8U`xE;uyfPVpQ03QaIfo)(FI20TL9)-_;0NfA$6bSEsDUh@Mrvtes;3o&82LS#W zyc=8rnqUp+13!Sz|33ICmCX;iJqCMO02&=FEs?(qdjUFWDCkNt`s)dS;?& zz;ZOuT45x7bK7I>#euoD4_`+2+EcGEsMc#?rAX1uT1Y;O-*L`C z&F6c(68K)>oG!oA7NCknNlBPKyc*Lyw2~_vD+7p8xQJMeU@3uru7S-BNsJ83d~R<6 z2|Vbm_@dx2zZM*YQ<20fuS`wJ-J+<1XszZLBOa(iYn^%hGXuj?DsVrp*+yqT(yzri z2ZLPLQ}8OHv}Xk!^F`BWCpM_C@hF2aZpn}u=}cQLxt!n_X5_jc+LK6g<_379&(OJ@ zf*YHR(9l$v7U{^8na;B0<~lJrrU|!X&WkBZg+nbqy6#0$UJ_$0r?bpASyVr=sD5OX zf?kj%H4{aZLLw7VSKf%pOhG|-^INr76EwbBNQ%TxQ^Dd(D*(-9WfD``b3y!(6^SaL zp3StY(u`fwTr6~vCn+;(B@!Ik4ies<3zWoUts#pr#bC{mIQoQ5glbUB5Pi%%44=5t zk+sy)7cEISRgFjOh6hO{u3zR0YnljBP2)pCTgOG3n`#g?ocu8UrBt1K;DRk7rRsm4 zCJr4N6CEEN-yaKPq}|ENHM&B zaG|a!i$>A?4}!^$xp^WA#~$&e@3vDUHQ6(x9{Dp4=Gr`Gz&dpFbrm>V-=!1vcS0dt zS7jU94O;iW=$}dz&*u?#V$s+Z@)3O(9epB^#)TRBMcgoE2MuI~&FERsP_5m7VRJb& zzd4?A*m)lMEgsU6k%ilHZgBjLr9vlWAQ0D8Jg9%DO}u= zfl}!K%+3Xo1_4BPY*sLmpNnf}5%nxP2N!JIE={caSR`JXrHfnSmPJuCHg&+d&NYr1x7qaBjwshz z8c(JB#O_8}(DKG9*tY?%q)u&=OHMTOVAK&U#eVj_UAmEKbspl$wmQ zd6jp$v9{%FqDe9x)o2UQRh8m5p237F6&ftJE+yg^?SeX$H!CP+c>vy?Ap#+`mE_3_ p5l4ncfmBcM9pUR}t58K2r?Sx!sWRl&UXlKVe7jKbCg4g^{~zqyoW=kE delta 16 XcmZo@U~W*@@RE7r1J=#&n1y%%I<*FP diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode.py index 4e2efd6c0..cbfe5ac34 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode.py @@ -659,83 +659,86 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + + if params.model_path: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) ) - ) model.to(device) model.eval()