From b51f1f4032acf3b453a3e3e2f3199178ad5cdac6 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Sun, 8 Jan 2023 18:32:40 +0900
Subject: [PATCH] from local

---
 .../ASR/incremental_transf/.conformer.py.swp |  Bin 94208 -> 0 bytes
 .../conformer_old.py                         | 1593 +++++++++++++++++
 2 files changed, 1593 insertions(+)
 delete mode 100644 egs/librispeech/ASR/incremental_transf/.conformer.py.swp
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_old.py

diff --git a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp b/egs/librispeech/ASR/incremental_transf/.conformer.py.swp
deleted file mode 100644
index 84d33d1fedfc29baeacd5e1886104d082a86f5f3..0000000000000000000000000000000000000000
Binary files a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp and /dev/null differ
zs%C#C6|;}1mdU73AWurz)-dKqtk;+8QeiE<^L2dS*)|EhFWi>X^@s;`d!1_r-;#rf z*mfc5t&tF#tRSV!WfC?=`~SA3rTik}e8o2jvwLC(R!lsvqoQv!ho%(HDWI}iWuXFZ z3}bw-+NyIZ(ur~v22qL*8Y8X1#5%5%ZK@-XT<2N^RxXiTOZ4Fq%68)yz^UslvkE*T z!&y`EQNfOi^KnVtFejn&lSg)hSb0#nY&ytUvtnFb9)bOipK*2!2EEYxhD8slnzuF* z+4;m|44h|df^f0@Tupq3{Ud$3f=z*3YRi9H{&yNAE*GKVpJ~7_P6-x%gyMKRpM21Z zi3z-fd%Y!lSds}p5h(S|HpGZZpFNC3aZ%uF77A(3fb`|B()pC>V=}ZO5L0pl_D1Tg z*)JQxxH!{B9c5Y>QGGl(O+pZN-Y(tWL&_BO!gl)#odg2(1~d|(nMx%SGyG@BQ_Tul zwq(<`C#PLt1YHb(VV8*&1=&+hoTuzb$y|A7v^PyMa_6er4^^(HM&{n#14 zr*OO`eU=N!*c#kw7j!~f4>24#=( zLmAIEb?Ke1;`a5JWMps*Zm|hZBi@~wP<8ouk}At7w%Bj1!Lm_Kn?q~<6%B?~WGGBI zC2U)QvA}%-w)fLk*e2CrV-77pGD`G!M&)g~LUyH$2vXGmQ!818581PYrJZ8OWT=XA zHpUoLPW3bvkUE<}DlsPoD7UI+1dS13y&#i^x3P9A zFlp*V+e(z#j)0OP+E3miY4PX}DxvfDu55jqC+c~!V3OGOM5&nfk8OUr?KJ`X^ODbB z->O1&86BZfiMnY|jHp_4j~nl9W=9Fc3r<>#20~+_1NH6dN;7ggJRq##$Qd!IlO6I$ z=eDM!z;=qPd6+D*%4OF`{!X!Fucr4qo_u_MQ@668Kxoz^|eLJ&@ z_U)Y5xo2|QZ?k zYm{v{)FYklM-+b=&;4kg?84+N6>@r&Q*Vr6k;QgQGTy3gn#D_d=J z?ZjtH8oE^$il+X|thp}RWxfEpbw}e-gWL!xEFEl8kZ>Gb$<98rH*Ie&&!9Ltic}{bt+~8|YUD<)eI*_WAOawM zIe;|hkX2w(e@FfIXe0ZnBjbF_tI^>#h^*0mzCvcGF`_3Z*!jV*jPrEXyqISia=|IS)Uhyhbq}ucUC&O$0wRld>|3Aj@onL_GzZz75?g4lp7y$nXpMNcQ zE_fQyJ%GEweZkk@_df_e0G5z5$L>tWB}s-p95OpWN>fr1^E1Tf!BkVfJtx; zI1&6*ygv9Va1iVP@)7VGKt2NA4$uEW;BEhPFTkIGXMi(-`~pa?e|vBZeE++^W5L70 zb}$NVfcKX^|Ife+!7SJZZV#@5=l@&qhv20^di=YA*Td^y3N8T;0bf0ivjX55;0mw+ zCcy;Q0(1|+=fNw$VQ>IE3Y-XT4L(8ps}J7OpZD*T2Ixr^K3Z0-P#~=NdaIR4GzzoY z>S1#$9+<;JDQ9TI^@(5YREH(0B&9ZYR#@-+nzJ?y9HkmYRQ~M{Mj6Auc&nUK+>mU# z6-}2z(e2@KNirOdKR6xs2Nna%Q)jW2fFoKlcphGa)5uABKV?g5gh;AjcA3nUh-F6N zS@VN(n^w{VLFS?PW$6IduG+3la%O+hP471Ispa}=ZO)&xEQfA!QV|--L^CDH|I#m$ z^SJpxe@wnqf^*t~PCzJ!yJd z|8e^@n6s3kwj47_j%<*Ul;kP9vg9Q;7oFPp zVu(NJU})F&TOTRc7O-X0zO;W!B^pVH0!ERPGHBtX46z0!8CabjI@Z##z9Z(lIJqp8 z&WH*^i%ZsUgd*GY^78hHtTJmuba!gC!#J{3l{$iEG%@WHTj;!9QkJ~7ij*IcTSoyL z?QFW*>|EVIc8?17LMyZAmT|+~zsme3glpS({f_ zWZ4wjU4$6oM)Hhr`a+W@!A)8Jv*{-{%Z(PousWt`5^28gXANpLIz!XBS)Yuqn?gbL zE;4Rd$F$RcEx}4l^uwl0omh4}+k|+;N$I#xO^T~z2t}hO%zc;EfTe~}-biRD;Tt;R zd{hvKNij9W>G0arlxE+iDf~=e_A#|7Oby-Dg6zjgCY1_NjQt@Fx8&l_8GMVV#K_SY z0f=cfzeg&kbD|~B&ys$cIKCd1XP0)*)NB73sx@@grhZ-nH3Z9;CxHQe0~bkaE0Sy; zc4ft`L!OF}WzKR!-;vyTW5$hT6Ag33(Vor;m@Fsq61c8lgL diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_old.py b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_old.py new file mode 100644 index 000000000..f94ffef59 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_old.py @@ -0,0 +1,1593 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import copy
+import math
+import warnings
+from typing import List, Optional, Tuple
+
+import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
+from torch import Tensor, nn
+
+from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
+
+
+class Conformer(EncoderInterface):
+    """
+    Args:
+        num_features (int): Number of input features
+        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
+        d_model (int): attention dimension, also the output dimension
+        nhead (int): number of attention heads
+        dim_feedforward (int): feedforward dimension
+        num_encoder_layers (int): number of encoder layers
+        dropout (float): dropout rate
+        layer_dropout (float): layer-dropout rate.
+        cnn_module_kernel (int): Kernel size of convolution module
+        vgg_frontend (bool): whether to use vgg frontend.
+        dynamic_chunk_training (bool): whether to use dynamic chunk training; if
+            you want to train a streaming model, this is expected to be True.
+            When set to True, it will use a masking strategy to make the attention
+            see only limited left and right context.
+        short_chunk_threshold (float): a threshold to determine the chunk size
+            to be used in masking training; if the randomly generated chunk size
+            is greater than ``max_len * short_chunk_threshold`` (max_len is the
+            max sequence length of the current batch) then it will use
+            full context in training (i.e. with chunk size equal to max_len).
+            This will be used only when dynamic_chunk_training is True.
+        short_chunk_size (int): see docs above; if the randomly generated chunk
+            size is equal to or less than ``max_len * short_chunk_threshold``, the
+            chunk size will be sampled uniformly from 1 to short_chunk_size.
+            This also will be used only when dynamic_chunk_training is True.
+        num_left_chunks (int): the left context (in chunks) attention can see; the
+            chunk size is decided by short_chunk_threshold and short_chunk_size.
+            A negative value means seeing the full left context.
+            This also will be used only when dynamic_chunk_training is True.
+        causal (bool): Whether to use causal convolution in conformer encoder
+            layer. This MUST be True when using dynamic_chunk_training.
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        num_left_chunks: int = -1,
+        causal: bool = False,
+    ) -> None:
+        super(Conformer, self).__init__()
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_layers = num_encoder_layers + self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.causal = causal + self.dynamic_chunk_training = dynamic_chunk_training + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + self.num_left_chunks = num_left_chunks + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + causal, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self._init_state: List[torch.Tensor] = [torch.empty(0)] + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, d_model) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 + + if not is_jit_tracing(): + assert x.size(0) == lengths.max().item() + + src_key_padding_mask = make_pad_mask(lengths) + + if self.dynamic_chunk_training: + assert ( + self.causal + ), "Causal convolution is required for streaming conformer." + max_len = x.size(0) + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = chunk_size % self.short_chunk_size + 1 + + mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + else: + x = self.encoder( + x, + pos_emb, + mask=None, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x, lengths + + @torch.jit.export + def get_init_state( + self, left_context: int, device: torch.device + ) -> List[torch.Tensor]: + """Return the initial cache state of the model. + + Args: + left_context: The left context size (in frames after subsampling). + + Returns: + Return the initial state of the model, it is a list containing two + tensors, the first one is the cache for attentions which has a shape + of (num_encoder_layers, left_context, encoder_dim), the second one + is the cache of conv_modules which has a shape of + (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). + + NOTE: the returned tensors are on the given device. 
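+
+        Example (a usage sketch; the shapes assume the default configuration
+        of this class, i.e. 12 encoder layers, d_model=256 and
+        cnn_module_kernel=31)::
+
+            >>> model = Conformer(num_features=80)
+            >>> states = model.get_init_state(left_context=64, device=torch.device("cpu"))
+            >>> states[0].shape  # attention cache
+            torch.Size([12, 64, 256])
+            >>> states[1].shape  # convolution cache
+            torch.Size([12, 30, 256])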
+ """ + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + # Note: It is OK to share the init state as it is + # not going to be modified by the model + return self._init_state + + init_states: List[torch.Tensor] = [ + torch.zeros( + ( + self.encoder_layers, + left_context, + self.d_model, + ), + device=device, + ), + torch.zeros( + ( + self.encoder_layers, + self.cnn_module_kernel - 1, + self.d_model, + ), + device=device, + ), + ] + + self._init_state = init_states + + return init_states + + @torch.jit.export + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[Tensor]] = None, + processed_lens: Optional[Tensor] = None, + left_context: int = 64, + right_context: int = 4, + chunk_size: int = 16, + simulate_streaming: bool = False, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. + simulate_streaming: + If setting True, it will use a masking strategy to simulate streaming + fashion (i.e. every chunk data only see limited left context and + right context). The whole sequence is supposed to be send at a time + When using simulate_streaming. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + - decode_states, the updated states including the information + of current chunk. + """ + + # x: [N, T, C] + # Caution: We assume the subsampling factor is 4! 
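+        #
+        # A worked example of the length arithmetic below (illustration only):
+        # for x_lens = 100, (((100 - 1) >> 1) - 1) >> 1 = (49 - 1) >> 1 = 24,
+        # the same as ((100 - 1) // 2 - 1) // 2, i.e. 24 frames remain after
+        # the 4x subsampling of the encoder_embed front-end.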
+ + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 + + if not simulate_streaming: + assert states is not None + assert processed_lens is not None + assert ( + len(states) == 2 + and states[0].shape + == (self.encoder_layers, left_context, x.size(0), self.d_model) + and states[1].shape + == ( + self.encoder_layers, + self.cnn_module_kernel - 1, + x.size(0), + self.d_model, + ) + ), f"""The length of states MUST be equal to 2, and the shape of + first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, + given {states[0].shape}. the shape of second element should be + {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, + given {states[1].shape}.""" + + lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output + + src_key_padding_mask = make_pad_mask(lengths) + + processed_mask = torch.arange(left_context, device=x.device).expand( + x.size(0), left_context + ) + processed_lens = processed_lens.view(x.size(0), 1) + processed_mask = (processed_lens <= processed_mask).flip(1) + + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) + + embed = self.encoder_embed(x) + + # cut off 1 frame on each size of embed as they see the padding + # value which causes a training and decoding mismatch. + embed = embed[:, 1:-1, :] + + embed, pos_enc = self.encoder_pos(embed, left_context) + embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + + x, states = self.encoder.chunk_forward( + embed, + pos_enc, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + states=states, + left_context=left_context, + right_context=right_context, + ) # (T, B, F) + if right_context > 0: + x = x[0:-right_context, ...] + lengths -= right_context + else: + assert states is None + states = [] # just to make torch.script.jit happy + # this branch simulates streaming decoding using mask as we are + # using in training time. + src_key_padding_mask = make_pad_mask(lengths) + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + assert x.size(0) == lengths.max().item() + + num_left_chunks = -1 + if left_context >= 0: + assert left_context % chunk_size == 0 + num_left_chunks = left_context // chunk_size + + mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=num_left_chunks, + device=x.device, + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths, states + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): Whether to use causal convolution in conformer encoder + layer. This MUST be True when using dynamic_chunk_training and streaming decoding. 
+ + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + causal: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + src_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_key_padding_mask: the mask for the src keys per batch (optional). + src_mask: the mask for the src sequence (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + + src = src + self.dropout(src_att) + + # convolution module + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + src = src + self.dropout(conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[Tensor, List[Tensor]]: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). 
+ pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + assert not self.training + assert len(states) == 2 + assert states[0].shape == (left_context, src.size(1), src.size(2)) + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # We put the attention cache this level (i.e. before linear transformation) + # to save memory consumption, when decoding in streaming fashion, the + # batch size would be thousands (for 32GB machine), if we cache key & val + # separately, it needs extra several GB memory. + # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val + # separately) if needed. + key = torch.cat([states[0], src], dim=0) + val = key + if right_context > 0: + states[0] = key[ + -(left_context + right_context) : -right_context, ... # noqa + ] + else: + states[0] = key[-left_context:, ...] + + # multi-headed self-attention module + src_att = self.self_attn( + src, + key, + val, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + left_context=left_context, + )[0] + + src = src + self.dropout(src_att) + + # convolution module + conv, conv_cache = self.conv_module(src, states[1], right_context) + states[1] = conv_cache + + src = src + self.dropout(conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + return src, states + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
+ + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + src_key_padding_mask: the mask for the src keys per batch (optional). + mask: the mask for the src sequence (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) + + return output + + @torch.jit.export + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + states: List[Tensor], + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + Shape: + src: (S, N, E). + pos_emb: (N, 2*(S+left_context)-1, E). + mask: (S, S). + src_key_padding_mask: (N, S). 
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+        """
+        assert not self.training
+        assert len(states) == 2
+        assert states[0].shape == (
+            self.num_layers,
+            left_context,
+            src.size(1),
+            src.size(2),
+        )
+        assert states[1].size(0) == self.num_layers
+
+        output = src
+
+        for layer_index, mod in enumerate(self.layers):
+            cache = [states[0][layer_index], states[1][layer_index]]
+            output, cache = mod.chunk_forward(
+                output,
+                pos_emb,
+                states=cache,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+                left_context=left_context,
+                right_context=right_context,
+            )
+            states[0][layer_index] = cache[0]
+            states[1][layer_index] = cache[1]
+
+        return output, states
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """Construct a PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        if is_jit_tracing():
+            # 10k frames correspond to ~100k ms, i.e., 100 seconds.
+            # It assumes that the maximum input won't have more than
+            # 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
+
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
+        """Reset the positional encodings."""
+        x_size_1 = x.size(1) + left_context
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x_size_1 * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` is the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x_size_1, self.d_model)
+        pe_negative = torch.zeros(x_size_1, self.d_model)
+        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+            left_context (int): left context (in frames) used during streaming decoding.
+                this is used only in real streaming decoding, in other circumstances,
+                it MUST be 0.
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x, left_context)
+        x_size_1 = x.size(1) + left_context
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x_size_1
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. 
If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + left_context=left_context, + ) + + def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1+left_context). + time1 means the length of query vector. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + + time2 = time1 + left_context + if not is_jit_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. 
+ pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + + tgt_len, bsz, embed_dim = query.size() + if not is_jit_tracing(): + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + if not is_jit_tracing(): + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None and not is_jit_tracing():
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+                key_padding_mask.size(1), src_len
+            )
+
+        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+
+        pos_emb_bsz = pos_emb.size(0)
+        if not is_jit_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
+
+        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        # (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
+        p = p.permute(0, 2, 3, 1)
+
+        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+
+        # compute matrix b and matrix d
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, left_context)
+
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+
+        if not is_jit_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+
+        # If we are using dynamic_chunk_training and setting a limited
+        # num_left_chunks, the attention may only see the padding values which
+        # will also be masked out by `key_padding_mask`. In this circumstance,
+        # the whole row of `attn_output_weights` will be `-inf`
+        # (i.e. be `nan` after softmax), so we fill `0.0` at the masked
+        # positions to avoid invalid loss values below.
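+        # At this point attn_mask (when present and boolean) has shape
+        # (1, tgt_len, src_len) or (bsz * num_heads, tgt_len, src_len), and
+        # key_padding_mask has shape (bsz, src_len); the two are OR-ed below
+        # to locate the fully masked entries that must be reset to 0.0.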
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+
+        if not is_jit_tracing():
+            assert list(attn_output.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                head_dim,
+            ]
+
+        attn_output = (
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+        else:
+            return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+        causal (bool): Whether to use causal convolution.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        causal: bool = False,
+    ) -> None:
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+        self.causal = causal
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # After pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels.  This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively.  (For most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1 is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.)  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        self.lorder = kernel_size - 1
+        padding = (kernel_size - 1) // 2
+        if self.causal:
+            padding = 0
+
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        cache: Optional[Tensor] = None,
+        right_context: int = 0,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            cache: The cache of depthwise_conv, only used in real streaming
+                decoding.
+            right_context:
+                How many future frames the attention can see in the current chunk.
+                Note: It's not that each individual frame has `right_context` frames
+                of right context; some have more.
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            If cache is None return the output tensor (#time, batch, channels).
+            If cache is not None, return a tuple of Tensor, the first one is
+            the output tensor (#time, batch, channels), the second one is the
+            new cache for next chunk (#kernel_size - 1, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+        if self.causal and self.lorder > 0:
+            if cache is None:
+                # Make depthwise_conv causal by
+                # manually padding self.lorder zeros to the left
+                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                assert not self.training, "Cache should be None in training time"
+                assert cache.size(0) == self.lorder
+                x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
+                if right_context > 0:
+                    cache = x.permute(2, 0, 1)[
+                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        ...,
+                    ]
+                else:
+                    cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
+        x = self.depthwise_conv(x)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        # torch.jit.script requires return types to be the same as annotated above
+        if cache is None:
+            cache = torch.empty(0)
+
+        return x.permute(2, 0, 1), cache
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+    ) -> None:
+        """
+        Args:
+          in_channels:
+            Number of input channels. The input shape is (N, T, in_channels).
+            Caution: It requires: T >= 7, in_channels >= 7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+        """
+        assert in_channels >= 7
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain median of output to be close to zero.
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    feature_dim = 50
+    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
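+
+    # A minimal additional sketch (not part of the original smoke test): check
+    # that Conv2dSubsampling shrinks T to ((T-1)//2 - 1)//2 as documented above.
+    # It reuses feature_dim/batch_size/seq_len defined earlier in this block;
+    # the local names `subsampling`, `y` and `expected_t` are introduced here
+    # only for this check.
+    subsampling = Conv2dSubsampling(in_channels=feature_dim, out_channels=128)
+    y = subsampling(torch.randn(batch_size, seq_len, feature_dim))
+    expected_t = ((seq_len - 1) // 2 - 1) // 2
+    assert y.shape == (batch_size, expected_t, 128), y.shape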