From 177bce2db41d6289db7117af8d763eb49cf0fb90 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Wed, 24 May 2023 13:37:08 +0900
Subject: [PATCH] from local

---
 .../.train_lora.py.swp                        | Bin 90112 -> 90112 bytes
 .../train_lora.py                             |  17 +++++++++++------
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp
index 914514cac6bdf6ef1e8279e9af706c553bf711da..3b7f923d95df056458cb198f893361111d218b57 100644
GIT binary patch
delta 292
zcmZoTz}j$tRV>LM%+puFQqO<^2m}}y{!Y$KkuumQ_DhhFyID|ZyPy!`XI|D|prrS9
zP6fv6jDjl6Kn@VEWMW{b0b=jz9~BvcCx6fr=gG-0N{r9VPbtkwEuNmJ#Hc0!W-Dao
zDZoT0DlulWLphUs;w1%<^(fdXM8|I5&CXcIw0*lOqc0b0IX?r#;^~|QjM9vYx2qa3
z-sk5H;bUN!1H>TLa!(g@U=)_y2voHai0gnj7l`G7_%~3=b+FRyjt-1be3SQ0LM%+puFQqO<^2m}}yK1|L{`KZ59?3W-TSF@ndc0nP=&%CUDKuNdl
zoC=KB83k3D85q7WF)+*p;xZt1oBmOeF_
zH4%tgf!G{~g@O1zNZEEr2gYc=i4SBZ@0-XmeTg6A#p%_dj6Bmz{TVH`U-W0x7XSb!
CuuXmd

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py
index d83874c08..2260aff8a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py
@@ -1591,17 +1591,22 @@ def run_adapter(rank, world_size, args, wb=None):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    lora_module = []
-    for i, module in enumerate(model.modules()):
+    lora_modules = []
+    for module in model.modules():
         if isinstance(module, fairseq.modules.multihead_attention.MultiheadAttention):
-            for m in module.modules():
-                lora_module.append(LoRAHook(m))
+            for m in module.modules():
+                lora_modules.append(LoRAHook(m))
 
     adapter_names = []
     adapter_param = []
+    for lora in lora_modules:
+        # assumes LoRAHook exposes its trainable low-rank weights via named_parameters()
+        for n, p in lora.named_parameters():
+            adapter_names.append(n)
+            adapter_param.append(p)
+    '''
     for n, p in model.named_parameters():
         print(n)
-    '''
         if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
             adapter_names.append(n)
             adapter_param.append(p)
@@ -1609,7 +1614,7 @@
             p.requires_grad = True
         else:
            p.requires_grad = False
-    '''
+    '''
     optimizer_adapter = ScaledAdam(
         adapter_param,
         lr=params.adapter_lr,
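
For readers of this patch, the following is a minimal, self-contained sketch of the pattern the train_lora.py hunks are building toward: freeze a pretrained model, attach low-rank (LoRA) updates to the Linear projections inside its attention blocks through forward hooks, and hand only those low-rank parameters to the optimizer. The LoRAHook and SelfAttention classes below and the use of torch.optim.Adam are illustrative assumptions; they stand in for the repository's actual LoRAHook, fairseq's MultiheadAttention, and icefall's ScaledAdam, whose exact APIs are not shown in this diff.

# Illustrative sketch only: this LoRAHook, SelfAttention, and torch.optim.Adam
# are stand-ins for the repository's LoRAHook, fairseq's MultiheadAttention,
# and icefall's ScaledAdam.
import torch
import torch.nn as nn


class LoRAHook:
    """Attach a trainable low-rank update (alpha/r) * x @ A^T @ B^T to an nn.Linear."""

    def __init__(self, linear: nn.Linear, r: int = 4, alpha: float = 8.0):
        self.lora_A = nn.Parameter(torch.randn(r, linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, r))
        self.scale = alpha / r
        # A forward hook adds the update without modifying the wrapped module.
        self.handle = linear.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        (x,) = inputs
        return output + self.scale * (x @ self.lora_A.T @ self.lora_B.T)

    def parameters(self):
        return [self.lora_A, self.lora_B]


class SelfAttention(nn.Module):
    """Toy single-head self-attention with explicit Linear projections."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        attn = torch.softmax(q @ k.transpose(-2, -1) / x.shape[-1] ** 0.5, dim=-1)
        return self.out_proj(attn @ v)


model = nn.Sequential(SelfAttention(32), nn.ReLU(), SelfAttention(32))
for p in model.parameters():
    p.requires_grad = False  # freeze the pretrained weights

# Mirror of the loop in the hunk: hook every Linear inside an attention block.
lora_modules = []
for module in model.modules():
    if isinstance(module, SelfAttention):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                lora_modules.append(LoRAHook(m))

# Mirror of adapter_param: only the low-rank weights go to the optimizer.
adapter_param = [p for lora in lora_modules for p in lora.parameters()]
optimizer = torch.optim.Adam(adapter_param, lr=1e-3)  # ScaledAdam in the real recipe

x = torch.randn(2, 10, 32)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
print(f"{len(lora_modules)} hooks, {sum(p.numel() for p in adapter_param)} adapter params")

A hook-based approach, which the name LoRAHook suggests, leaves the wrapped modules untouched; that is convenient when the base model (here the fairseq data2vec encoder) cannot easily be re-instantiated with LoRA layers baked in.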