From 1de6d4b030518f523e10a28dbc0e6fa3918b6361 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Fri, 26 May 2023 11:49:53 +0900
Subject: [PATCH] from local

---
 .../.decode_new.py.swp                        | Bin 4096 -> 40960 bytes
 .../decode_new.py                             | 803 ++++++++++++++++++
 2 files changed, 803 insertions(+)
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.decode_new.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.decode_new.py.swp
index 700042a1728cba4d6724f97b8a5b82ea1a285382..d75b38d11d64dd0a90ef9621c35b68bb60867b36 100644
GIT binary patch
literal 40960
[base85-encoded binary payload of the Vim swap file .decode_new.py.swp omitted]

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py
new file mode 100644
index 000000000..4b3eac067
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py
@@ -0,0 +1,803 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
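+
+# Overview: this script decodes the LibriSpeech test-clean and test-other sets
+# with a pruned transducer model from this recipe. It loads either the single
+# checkpoint given by --model-name or an (averaged) epoch/iteration checkpoint,
+# runs the search selected by --decoding-method (greedy_search, beam_search,
+# modified_beam_search, or one of the fast_beam_search variants), and writes
+# recogs-*, errs-*, and wer-summary-* files under exp-dir/<decoding-method>.
+# When the inputs are raw audio samples (--input-strategy AudioSamples), the
+# feature lengths are taken from the cut recordings instead of "num_frames".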
+""" +Usage: +(0) for d2v-T decoding +for method in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless_d2v_v2/decode.py \ + --input-strategy AudioSamples \ + --enable-spec-aug False \ + --additional-block True \ + --model-name epoc.pt \ + --exp-dir ./pruned_transducer_stateless_d2v_v2/960h_sweep_v3_388 \ + --max-duration 400 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --encoder-type d2v \ + --encoder-dim 768 \ + --decoder-dim 768 \ + --joiner-dim 768 +done +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, add_rep_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--model-name", + type=str, + default="", + help="""It specifies the model file name to use for decoding.""", + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. 
+ """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + add_rep_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 2 or feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + #feature_lens = supervisions["num_frames"].to(device) + if feature.ndim == 2: + feature_lens = [] + for supervision in supervisions['cut']: + try: feature_lens.append(supervision.tracks[0].cut.recording.num_samples) + except: feature_lens.append(supervision.recording.num_samples) + feature_lens = torch.tensor(feature_lens) + + elif feature.ndim == 3: + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, 
+ nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
+ Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
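+        # write_error_stats() returns the WER for this decoding setting; the
+        # values collected in test_set_wers are sorted and written to the
+        # wer-summary-* file below.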
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.model_name: + load_checkpoint(f"{params.exp_dir}/{params.model_name}", model) + else: + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif 
len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main()