diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp
index 914514cac..3b7f923d9 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py
index d83874c08..2260aff8a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py
@@ -1591,17 +1591,18 @@ def run_adapter(rank, world_size, args, wb=None):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    lora_module = []
-    for i, module in enumerate(model.modules()):
+    lora_modules = []
+    for modules in model.modules():
         if isinstance(module, fairseq.modules.multihead_attention.MultiheadAttention):
-            for m in module.modules():
-                lora_module.append(LoRAHook(m))
+            for module in modules.modules():
+                lora_modules.append(LoRAHook(m))
 
     adapter_names = []
     adapter_param = []
+    for lora in lora_module
+    '''
     for n, p in model.named_parameters():
         print(n)
-        '''
         if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
             adapter_names.append(n)
             adapter_param.append(p)
@@ -1609,7 +1610,7 @@ def run_adapter(rank, world_size, args, wb=None):
             p.requires_grad = True
         else:
             p.requires_grad = False
-        '''
+    '''
     optimizer_adapter = ScaledAdam(
         adapter_param,
         lr=params.adapter_lr,
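
The added lines follow the recipe's adapter-tuning pattern: scan model.modules() for fairseq MultiheadAttention layers, wrap their submodules in LoRAHook, and hand only the adapter parameters to ScaledAdam while the rest of the backbone stays frozen. As committed, though, the new loop never binds module or m, and "for lora in lora_module" is missing its colon and body, so the hook list is built but not yet consumed. The sketch below is a minimal, self-contained illustration of that freeze-and-select pattern under stated assumptions: collect_adapter_params is a hypothetical helper, torch.nn.MultiheadAttention stands in for the fairseq class, and AdamW stands in for ScaledAdam with params.adapter_lr; LoRAHook itself is not reproduced.

import torch
import torch.nn as nn


def collect_adapter_params(model: nn.Module, keyword: str = "adapters"):
    # Freeze every parameter except those whose name contains `keyword`,
    # mirroring the commented-out named_parameters() loop in the diff.
    adapter_names, adapter_params = [], []
    for name, param in model.named_parameters():
        if keyword in name:
            param.requires_grad = True
            adapter_names.append(name)
            adapter_params.append(param)
        else:
            param.requires_grad = False
    return adapter_names, adapter_params


if __name__ == "__main__":
    # Toy model: one attention layer plus an "adapters" branch.
    model = nn.Sequential()
    model.add_module("attn", nn.MultiheadAttention(embed_dim=8, num_heads=2))
    model.add_module("adapters", nn.Linear(8, 8))

    # Same scan the diff performs over model.modules(); the recipe wraps each
    # attention module it finds in a LoRAHook (not reproduced here).
    attn_layers = [m for m in model.modules() if isinstance(m, nn.MultiheadAttention)]
    print(len(attn_layers))  # 1

    names, params = collect_adapter_params(model)
    print(names)  # ['adapters.weight', 'adapters.bias']

    # The recipe builds ScaledAdam over adapter_param with lr=params.adapter_lr;
    # AdamW is only a stand-in optimizer for this sketch.
    optimizer_adapter = torch.optim.AdamW(params, lr=1e-3)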