From 594e74d59472b3c99dcf6ae8775b9c910ad6aa23 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Fri, 26 May 2023 11:51:30 +0900 Subject: [PATCH] from local --- .../.decode_new.py.swp | Bin 45056 -> 49152 bytes .../decode_new.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.decode_new.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.decode_new.py.swp index 0889478866126d60e42bc276b3189b5cce0956e2..3014e32d73bd9005315d95ac0c940a83beec6013 100644 GIT binary patch delta 1467 zcmb8vTWl0n7{KwbH<8x1+t_Z^pgB;nTi6a2?*i5gR!7~`e3HHMJ336UBl@IZ~c5HO}*9)uSPp(TWS0ug=jfA&%}^5RK;oynO! z=X~eu`A+t9X*OHd(B0OxzEeGAv{ zA7cPt;(fe`$1zJhN3gw0GO~%6CY0lNrN|^qtiWP?y_j9$EwrHux9${~!318zYAnIr zB9UqAKonuj-yw1s!{|gQu2zWrh)2)}b0FOGKtHjCb)4(pYGU*6&$%-0;1rD`Ph*gVFJZnq_e>X{v;oH2YI3uDc%j zcb8{c%JN)$S&QfyicJqQa({d;Vg)1PVdMXe3*H;QXZ_!^#91en@=f24Sx%21(O+V@ zu>QMlzF)0XzL)IjBMQ?B20yLu$c@>V-;3e+c5+v2r?bmQ`2UE#*GnoVs`O8l<0{)t z_uENU%`ZrP(s4W0^0WlA!}sPk6f5)@W+Gv`R;0swzWq7R+aAgF2cwhqL3UO7Lr$-9 zQjW_G+#Wms&5jlVbX}vvacwi%5=?&*&g3W5s~4v(JR_`I%&&glZq8Ix^_osHW$!Q+ zMyKPx(`S=E)Aw!HcRbhVEcUe*yB6l&q?ArVVRRu3PQ`74$R)J+!?o&_R~N!-YRXNL zOJg6|aNMmz&L&ny3}cnGdZ4x2Oxdk3=-fy1r!N0*wBu~5Sq}?_MwSHUCYpn@`|r<~ zmdMYX)Ri1$eI;xtoZ6d|{|QW@3y+}*e{v1{iZN`)v#7^qu7MG}jCH8SMN0o8q_GsW z_>HPRgh6!c>!6XB)0F)YY{O=(!5n4(9rojGm}td$%Kj&Oh5i4nf>c0RYl~|hq delta 425 zcmXBQPbh-{9LDkIwfQqI+i06O=;cHqxoAlow}Y(w%SDlL^rxiduqjgHu-=PSgyi5r zwgcrX7bQ*><*jkEP!5#jD^Gp;J%{J?OD0`OrzkklJ{;|{!fov$erdgaW^()_SnwGh z+vP79G5g(LlWnTH^QIgtW4V)wRfc)ve;woe<;*7?VHuO~qTrS8k;Mi^5x`HSbcF+q zpc>y5(gpU>k5=SN=@Kc-A&fdx30;>rBWKp z7(z2loR>&57(+Kg$Q4U_m_`I0#RHN>@#4WP9bySxcymd&SV0t>2;sFz%3u)@G{D3c iU;ISgzL^cCoOpJA*D&qgbc^krzGPRno($VLwtfMlSxpiE diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py index 87f3d813e..671edd23c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/decode_new.py @@ -666,6 +666,22 @@ def main(): if '.pt' in params.model_name: load_checkpoint(f"{params.exp_dir}/{params.model_name}", model) + elif 'lora' in params.model_name: + load_checkpoint(f"{params.exp_dir}/../d2v-base-T.pt", model) + + ## for lora hooking + lora_modules = [] + for modules in model.modules(): + if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention): + for module in modules.modules(): + if isinstance(module, torch.nn.Linear): + lora_modules.append(LoRAHook(module)) + + for i, lora in enumerate(lora_modules): + lora_param = torch.load(f"{params.exp_dir}/lora_{params.iter}_{i}.pt") + lora.lora.load_state_dict(lora_param) + lora.lora.to(device) + logging.info("lora params load done") else: if not params.use_averaged_model: if params.iter > 0: