From 2a54d49f969604e9ddaa9144f2624fd2d69cfd09 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Thu, 18 May 2023 16:55:27 +0900 Subject: [PATCH] from local --- .../.data2vec_audio.py.swp | Bin 4096 -> 45056 bytes .../data2vec_audio.py | 96 ++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp index 28c49fc845a8d483376fa4b888f3b144e3376762..26c8b1e144b355be06b35c9d9bdfe5dcfceb5222 100644 GIT binary patch literal 45056 zcmeI53zTGKec%fppb?Z+qk^tnyOpVep6Uk}prt3w42%vA;KHbZc2eA~y4_uLy{7K1 zneG_}HtY(RaFWeNHbzVYl^BpHU@*XLqN~wW(F8-1vnmmeWTWu`uAiYz*o_zZOP*6tsdctTkr#)%MLyHvGt+ z*%#Dn{-Dw9R(osNFh~DK5*SGU2@Ime&8HtzI%Q^BRrT?)S9mXY{u@W48A)Iyfsq78 z5*SHfB!Q6xMiLlFU?hQmMG}a5&n{g-dY@*}TsGfFWWJwi{{K7kUCEd~!~DN=i20G` zerUcAm~R`-QRaVOzQ2<(f0p^bXudBr4z2rVoBO|LzVFJI|BSi6*L=UoB-pxty19R@ z`Mx@1Zf=jhMiLlFU?hQ&1V$1VNnj*_kpxB(7)f9xfsq7868IM>0l!fyUGbb!=`}|a zedhnI{{MyNl}h)*+u;H@AGW|GjKgv8-RG7{e*yQyJ#aVtF5Cp~gezeIw!&F(6x@3Z zcHk=L!zJ(wFbick4qgo3Cq7?<``~xsCfEx-khnfdoIe70z}sL2PJ;i1qW?E=J-h`@ zhaWRI@EN!TZiWF|2)kecPKIZ|4;f_m27CeTgd5<$z;)1rZLk%d2VZBv;y(B@xEZd3 zx51lWGd#k8$3t){%tHlU3g2cR<$8EKybfLqC&IJfDEKx5DPMsv!!7U&a5Q{_0g89R zRq#gG4b$*r#wnhF2jJ6i8+;h9gZ}_mf)Cqa8=M5sgEAZgsmqJ6=Azg08?{!@S&Eju zE#A~@$(oIOgHEF>#*Id^vs8)=R=qiIwW?NY^txeCo|^EcE3xUgx7%Cww!wC9^A`0( zZftk%B>?fGAG;wIeXkjMoo?iv-|YmmY0KhZp&$4wu>-jyS6`x1Qa{ZHUNybBIWJm@ zYAt^)=ok1UR+U+;AB2Ne6juFSFX%MN<<$xAz=YT8S6q6m=gQ+Pb6&gI8S1g5DhyhS zRlni)qM#p&hdQKYW3HL^G=E2e!;$z>^P{L!6Zs?lTH6m-=8}IVyiUNXnf*a?X*mk# zcKEF@m?+uk+Hd?S>m(636_s+z4#^U!t5}nz%?@FC-M&|D^t-+8AgcBH-35Q4*=j~> zln-T3c_Zc(S6U>zA}4cRr&ncRPovv5U*(ENplOe!;#hVy%z<--+XDZr}SN+FI*blUxns!6eAb8k*X%iNjp ztT6R$`!Aj06@=!+(Um`Pr88Oi+7I>)ntgIR4E9Q8t+gq(&0Z^*@EXnb+~(;CZ=c^9 z1arIl1DF21XEmgN_qC`S`{0U$$UmK-$(YQ~PI#Net!^u3_u!R;mX3eY1r}DTonU|M zK+x}oRIX~+JV0|%nc%+qIbGp*BU26yDcXoiirzQ6%oo&N<7kN=mtYc zj*q-ZO>L?kPKGm9hEpZ$nPh96;`Kx~F7~{r>qW}}MSJfc=+pzxw5sG#&#+=3Of@|9VH5l1|3UgxG< zV-`2&bCTKaezLS>jL5o#uaY|LtPw1F!78b%N43Skj|MyfYWthfa?P~aWtY-^fN@8f z<5Wwc;~K}9a5j*aO1{~A^rSqio30?I{1v~Gmy>JqZt}MiRs5@zG%xy{5@Yhugv*yA zcXAKml&o1+P_|edM&9CFvV{hJzu#}D=A9-&)xEgj=hrgIIr6nw1P6v>C0;J zFnINJqJD>i9tp2hYc-ziQLR>Q`C&-A%}QG$!SnQ-PQBX*`o^zv-lQ`7RNRzGt?trN zAdh7EL8qd+6hyn!J)#rxJJ=Vr%GPk_`8zJ4jalq>+nzLJ-PRzY#q^r(Ubi24+ubdl zTw@hWtG-7-O&DO`_Z!fZHlM;>-U$a zqhWw*?CsjQ{haeJItO1`fmCIA^7p%Rr5kkA>WFF0>n?Ja9!l69^yv;n{UE5~a(CjM zR6c3@TL`%mB29EIYK?ey-p<|jHVp>%O(>;t{Mg>@t@Wkj@0IHnZwD*v={l`Ucso1w zs(Pw4qtNpgiCmK&b672j{{J-e&K>BcqW_<36vD5d-~Sp6Ac9>`hU4Ke^!V??AHqlB zS78sl1-5|b`JY3F|15}(FZ%ss=1b?^?j z91g(>Tm)0_T=*(_`^Vv3@LRA37sKgr8a$4!{y*SuxB@PR22@}Sj)d=_um4Z@?{Fhr z1DC@JTn3YHEIf{k{u15~yFp}j4vemm1V$1VN#I|v1a$XRf2KRSa<8R2wlZc@FkXX3 zEFI0NQg=c%2GsPkded%ncbi%MrUW$x#U6wQ!ounVMyZZu+0#rsdI3^t_A`lui#f6*GLX7%|f*GeQEi)yG zIrXW#hBmUXV%qJr*3_6z+FXqV%18QkN_C6F6i;;z#MvfJq*SNfXkE&FVn9y}8HZF| zm$WK2;i>TTfMY@c*@MV~D30REhyaOmelMIU1z0HhU^HDw!dMBO4sEE0JQfydBM_jtORW>m) zY%JRpaTmJyNdS2ZhcT~bLXf!4dF3@7pK3U08>y7UdwUVzId5&KX&qB7+=xD9w~u(u zz+0s%if%%vR2&uOW~lVTrr%oNMBtF`Urh=jDb5`qtQb$}vB^LfEBO>|>ZEn!-Qw&NrdRAblS+zkA zjfmD@i}z||s#FJ?6E>A(#F<*c+4Q?&GAv;cDpL^UWUOx=M#|38|1&xN<3{HvlD122 zUf#|uG4d-SFUTrWXOfk2wYu2q`cXMEt7RQK-(8hUI$v`oRt8E<+tTUQmKY~VjZ`J0 z6#8xECu-92C~sTcFs#ZHPAlF1j3q6qxhRxr9koPaZCu5dAROB>KgPfYH)UAec`xmy zG!G%0(oxbSI`&exXi%nU z)(iU`XOT3qCbuRK1J4;5(3l(RFz`J#;VlG_uV<80pz=56m)s?lJE~!uv54hbYq}=4 z%1Zx#AM?OJ&m6Pp|CgFh>|dkjKLB^b4R8=-{$J+(kAfrNPtfx}3GajV!i~@Z83Q;0 z9!Br~8r%ujgUt2+EW8lzMfblJmSF-ONALeKd4B$)fQMePXf<5qZI2Im2|Nj-Z7+w!AgL~2S?}k;_4>G5J zI-CYCgXh4{z?aeeuYoh+CGZ)>Q9cT9gZ~IVOu)%-H2e@5KLiiLr{E)S2mCsO(1%Mv z<%P+I(bq@3o6&la z(RvZBRmjY9tzhj>##FES8Lby#F@w@~M(aiT;XlQCkuu{qW{^`vZkLgm!EZ5~_0o$R z7yx`~3r1o+)(jG6R}-c6c{OTo*YBe7&dXwDSu9&VX;4w~5CfTnNLfoP|Fr$tye2Lt z_4(PvQpY`R+}t{o?}fA_S=73a)c^H*)`eone%f)~k`aO%w38qu3*vP8j0Q}2b>v+V zGkd88f%5Q8B|+)`qI=$tek%IE-QWM`==UFi{jd*qf~@_21ik)d_&unB2R}xi|2SL= z9caTD@FF-CzKeeUD0~L)fm`9%VGSB^30w@P!jbSyxF0?LgYXXc4QRu;uoI4l|BY^c z2fQD8Z~}Y>ef~kX4Qg;691YKbZ=%uoF&(7s215_umJ1!0m7Yya#>< zegl?a33kF`i~&3XpMnp=yWtclgX|M{g0XpI00S`53_#b z&)_}qRyYU;;0%y_JsLK2=>fX4<(&%bPTDH3R8ZT6%Gv!x3@)9O-a%9zlN~hfRSc|2 ziMQ0urF_(To+?C08V56;#V&hRT=rR`(_;fBt84WPqOy{@qn7Gy_UePaybU7gv1}*m zMy%46KYBr}EaueyvDmO5m8AWVWrBXE6R`fLQPZ2n#f)`&od$zt);G3iJI`bUO+~;B z2C{CEU3o0EIXJc)w0dK+-dGrQd-0}j8SIngaZR;4M=fuYw_Aj)&uMk{t6e{1huokZ zG4s(;=BBDs)B0|572&L>7W|MJwIfLn`MJ0_$E?5HNW}Ddt|5_936;>@s;^x5+K#_! z76;BUkQx!Yojc28T^2O9n+I6i+3i%@v6V5h+2eAtE{N#Wkf(sK;)@B00*wb1g?%Evm+N!&XuCVUqmG<8=7QBoa^^JLDTsqBgnMolj@3ILFe|ZYm{WVXDel|;7K{|u0uq>&C z@ELE$4#jbn$SVJHMMfefsVX#01zZ@eY))nIaV}(2td{!SK~K$+*V_{mI$DLl@*W{0MQt^7$JH)_wE5|Z~c7jmIZl^j=}tPD>tl;1Lq zPTTLSnK~nd!{5hY8B2l5J@b$KvAx=gsWG<56}CEw)=To%HG{IWQdL!UId&0GU$0PP zn=dO57o--mm|B_mov11erR=v57i@PH$I909q*{N;lK+ZZI_>JGO1y+?lkN8Xt96gc zqgl`H_MBG@sH_(+OpEn|gKAt5af)OOrJ69(OD4V6v`R~r6G>zXLg)=x%iWq$H;pt^ z9K>C&9+{k zLZFYhp)|15#O!0|5zV*}lEg_aRT3mD@_!RgN3vy;t5n%@5fW?COX8Y+L0d8>%h-bH>O3bc zxPlOuM%%0=4O1ECJP;A7;3PL24h-ccB`Pi6lCC%P5FQgOJVmA}*ylfTkvK3sm zu(_7NRC>0j$Q~5^SdeEOtyo=k-n_CpNk)<>TiA3i?UW_^(lcZ4c;2rC4i{+9Rr8jk zs29#oPWt`T=DupTzckrvEKW|JF~#fUPCxau)2sZAqpfv~3k%Pzps?w}5_iEx5Vi@- z9=oEv6bo``ano+e*gsi3n*Y~w(T|Jb`F~W*(jOeb2l_v~Lv{TrbNw>^|8BS*E`?2S zB0Lu!M!)|W$XS1v!y92YycoWPKL6+N`*00h4u{}X@Jjdxbo$T3`(PE$ffvFH-~n{{ z8{tye1m8xFzZsTc2`1oM=^Fj(;V1k3>#zgl{Jt~b zIQTw#`)A=p@P7DZ2;fxs9{T$y;S(VH{Y9sjx&9O3W$;3H0el@j{vNm;Zh~bv1AZ1{ zt^fO=4pVSEd<%X3WAJNm8Jr5D!#{?OelJ9@3|rwWc$|E>6IA|~yeXM1_0;PS<#}#S zb5Z{}?4=%+8$rnaGgYPKfe@X8OBCAH}#WWK0 z-19^=^x|uBi{P2ww0kDx!;d!;Dt#sGrL|prS4>riHl`Up&b8;(AN!q<(r`;DuVZy< zgZ5QoitQ>RN>F}$q-5h~zl$I9Vro4pj%3MtV840RC(`V3c4Eq5MsVru!fnd34*C@t zFJb{X$3e?Fjb}pTkS*mIg(SH^%Q=!5D;sTpC8(+L7x;C?DiYlSv12y)H9L#l^5Phl ztkNa4HMwjoOqfasiX)QPZ;rC$PQBDlJug{#O7>=4`6;{!xt0^Zb@0-C;DDu+ zB}p9zM!Je>t8_9b7kddaTxp^!25aG5{3~6yL{4#7G=4kdCztjmF0oQ|WE2kpB#jUaZ@;f$~b>j%wQVd?}f z8<2b}SGLaQS|@!JV#zwNnQ!f@&WahKiY>)bm4-cFcV=wjeu;-7=EdGAGX>`Wh>@xp zT1&FQsasB*E}sY~7APHda?Y)$%@ut~u~Fn~A&FzCq>(61cs%7YNSP+exKhB33Zu+a zcUT{CN^6h$SIIawooVT?&Uw3HB}_ZkGf~|>6;(MEA;~qajdjlAZnqmtJSTn18G>>$ zKCG*9NJ!E+vR>86MYCh5NB>b{e-t6LtIjDOS?ntnJZ?&5#03( zsH-lS@n~vE*|FA&hOCXxMMjy}sAq$*p39MRKs+C@C_S{SA65>!9mi5Ep;(0To2U28 zPq_ieybYWM9L}uTT9qBviq4j7n@Lg7(RPu{oZ)GnAgNIz5v(eHvtGg2GmmkZ|341B z@(V`)Uo<-AgXs7-!TUkZ0sK`s8|3W&r^7?&`5%J+1hU_M4kq9`==Yxlnft#5ZiegN zYIrLw!t0?7uK+v${|R*ex5HkThg0C$AmagY&i?_}4zGq|;Y*AI{8#vGH~=j;8@9tM zLH7NB4n71jJ|J`bGCpt|$Xx&ZAnX5EU=uuw&i~i&pW)3=ftSJ)==eW^JK$~bI+%dx z!{4Ig-vw90CGcMl1o&yedD%jXgL=5W)M$zkC>j}{F!nxs;mN73Hs76kP|}W=p{K(l*2GN z{>inHs*J0sBodVLMusY4U+7UPKa%ONBc~Fl?fPKTP>-1oBI(eWr;$3UGf=5EN6#{v zsw-$ooXTarQjroxS_BY3A&4OU*^Dq7q>@o(-cU{wNlRA{Q_vQf)8*ImYqI?$zfqqe29Q*75wvnE9$zsuVuyzK=OO{qkh z5=aKw^MjPJxt-25bmT&}42Vc9+~AB=rc<_6|4m*RA!91>Ir%9ox^tWuxs+ri<3lvn zLtpO8l^^&ju5ZKY+4y0syFx+dBwt=LQ7f)6CTw+T`S1#t_gvDa07Ga@9tUwIJZ&5HeqD_f!>#`lVy46A_aRy^@jQWVaag;tXA@*MtAWF7UUHd!u{@*(?jBYHpGwWMZq~1Hwx(q&((1e_IJPx3e#sIHcQBvGJzo?z z@X%qfWSs|&oY+6S>=)OYOgz)EDoRQ2A(M6@NsG&GqDCp1+if(LG z#w79~XYSm;fRj;N{7RvZDTtQ^94UMKf!B!G@mYTO`xm3~lopvgM zv*kDpN?zK)8b($e>M2S0)6GbpBu52IaIEPHi-zop0#=X`f9Gg^#c9)GL**tqI?o!` z;P!d_Qydgl$wwB0rYk3}xCnS@#ER%T&T`2HC*LdNyq@nw;TS`|_uVoyTfk&oAU8du zI{(GVRw}p)Z+PMpO{vLb<;rEf#8BOu&%`lxB&tjb=ot%l!0^zxi~j!-^vPR|{{M?c zx4a7-{}#9z4#8Vt1@^#sFasyR(?HGykaqyy2Je6?pbd+#3tkVegm0to|1roJ0hhuS zcp4l5e~GUD7w|#21v)SVkD=?o7i3Mq8k9ly{~rfmLC=@>{QVMi;duBS`u#)jr?3i_ z!MPyo0$vHvgzuy8%RBuR;VhVhZ=&nVyZn9=YOn+33;k$R@&6)v{=bJ1 zUIg;4zCCahJcus;Rk$5)gKOb+a0J|g9v_0}^$Tzch+h9A_$!b+_#h0sB>UyEnJFDD z3S(|k+_hE|6IK$d1GeLFd`h_L-jT1M9qdr_bD z^2xvnLie6nB37p9VYr0%|D{f9O8%qml_R?|LjncPVt>tBla$9MrR1_&$VgE34q~p# z!aei8rfsa~H^U{;^RtOA9Br_MafpeO@WL1KzOVHnj!T^a+f8)=szbD;SY{)fze5sS zHjAuF@{zoFD`Lnnc6o8r!LYZ&gR_5V@*oqIyznYIRPFTEHu1EU`2mS^jlCt+dT+p+ zll0(8#&_}9xF!DS7cNTGljE}GO2)p*af6r7kIWct z(xH{t)Sg5)RU{#o5pP!e5~E^4nWN$7rLG;OQi&%e`N&?~u-aas9N10kMkvE!F-IfgHFwM zt7Xl}A$s26rC&wn`AmlMYFrk7u^c(-X58bD{Tk8!uGi(&cWPl|h&H#vzLSubJu>$1 zM;>K}{c*f7m2I%e;PXVL}nWV*>UEoJ>s zOvfXWPo}e0JKUs0YPunvmT|3Qs@6bCH!qO&iGVB!WUhwEj&zWzXs5g--uZ&tVoHPK z<<$}A$yo%cXo4pgZPKvH{-xAu^0KNu)%_soqyr~1n@L!(KeI0|b(U>@Tbnpi zU19UzPUtXUWo@J-S$qb*^)+AX#l-qVSjm?Q%6q3tJrYRzI z-f|aZ%zmA?Va;4mwE~1^lrE4)iP5Y#tgvpyTd(`|t>Tl^AqIa_Xf3HzKKaXC2 zFUZ>ecfs4?5;zOS;P27r{~O2}fE(aySOU@cWi3Dzse+VCiYv61+9=?U%|4z6HdaxVD;1P8H2jFuc?+lc80c?Sn z!_n{@_!PSTpTGy;0yq)ANa8;R_B_Bgzt0BQ^DplKxD&2~9(3U~@O+TH0k^|#@SD(q zU7#qkenQn^Ww{9i3B8st@T(FZ0&aHT^-bkcp zABru+mJ-hZwti;xR2n(yn$kyT%WF&%F&Fo8NkBWQc3@Sj-0{mf5`MESe@5OTEXz$9 zBVm#}^@?rr%C@L7l9Tc<-g}%*qYFNv?#E%P*D%STHC@HojLww|$+;a8jHFO{ndU42 zZQtyvk|h~v6~$)x*rvy9CH``WIuoh4!4(BjQg$*QJM|dl?G9b>}mmXz`*b z0FsyPUal)0vkm*7H*GW741G#z8@6G7qiI9dnc{qytdwI3wr$@n+X=Vr-d>VOC$EiW z&N{WRxPWErrzf6UYP*J2+_g zgJmVu2FmU3ZQE73-M%{)%#%)GQRA-51eo^^0@5TCp45VV=wOn{BGq^GKGK$aJ5YwzNyE z<5`V#DJ<5u4y6QTJjPVT7NApoex@~kc@b0$YzN3g88Jp&9%%9ZRt%sncPc0jf>~EQUc5R zEsLt10E^Y~*n@PL7?W-*?~$AJ4kivd(LEoPQL#H@=cV1sohMtCo?JN^lQNW8EFd8o z)zK9pgqPA}6CGG56mh9blKM<%;Ree@A|h-pR{kW;-Pf6-ryk4cc3i3McF-;}I4Tj7 zp1w^&Np8rMuhHq$h@HYQfuft~^u|t2oaT}d%v(*0CvR#ao;E>1Nl$HsNUraYSja{G jiPc25Sud@t)XuBKFilChHy=Z0eyv~CI-?2!)#U#ICI6O^ delta 16 YcmZp8z|^3y;U)9L2dtalF$?hk06<3u*8l(j diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py index 0f857723f..0e8b4a95f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py @@ -133,6 +133,102 @@ class TransformerEncoderAdapter(TransformerEncoder): return x, layer_results +class TransformerEncoderAdapter(TransformerEncoder): + def __init__(self, args: Wav2Vec2Config): + super().__init__(args) + self.adapters = ResidualAdapterModule(proj_dim=512) + + for p in self.adapters.parameters(): + p.data /= 10. + #p.data = nn.Parameter(torch.zeros(p.size()).to('cuda')) + #p.data = nn.Parameter(torch.randn(p.size()).to('cuda')/20.) + + def forward(self, x, padding_mask=None, layer=None, tgt_layer=None): + x, layer_results = self.extract_features_with_adapter( + x, + padding_mask=padding_mask, + tgt_layer=tgt_layer + ) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features_with_adapter( + self, + x, + padding_mask=None, + tgt_layer=None, + min_layer=0, + ): + + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + # pad to the sequence length dimension + x, pad_length = pad_to_multiple( + x, self.required_seq_len_multiple, dim=-2, value=0 + ) + if pad_length > 0 and padding_mask is None: + padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) + padding_mask[:, -pad_length:] = True + else: + padding_mask, _ = pad_to_multiple( + padding_mask, self.required_seq_len_multiple, dim=-1, value=True + ) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + r = None + + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + x, (z, lr) = layer( + x, self_attn_padding_mask=padding_mask, need_weights=False, + ) + x = self.adapters(x, layer_id=i) + + if i >= min_layer: + layer_results.append((x, z, lr)) + + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + # undo paddding + if pad_length > 0: + x = x[:, :-pad_length] + + def undo_pad(a, b, c): + return ( + a[:-pad_length], + b[:-pad_length] if b is not None else b, + c[:-pad_length], + ) + + layer_results = [undo_pad(*u) for u in layer_results] + + return x, layer_results + + class ResidualAdapterModule(nn.Module): """ Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf