from local

dohe0342 2023-05-24 21:18:23 +09:00
parent 9b45d8bd93
commit e0858fb471
2 changed files with 11 additions and 0 deletions

@@ -687,6 +687,17 @@ def main():
load_checkpoint(f"{params.exp_dir}/{params.model_name}", model) load_checkpoint(f"{params.exp_dir}/{params.model_name}", model)
elif 'lora' in params.model_name: elif 'lora' in params.model_name:
load_checkpoint(f"{params.exp_dir}/../d2v-base-T.pt", model) load_checkpoint(f"{params.exp_dir}/../d2v-base-T.pt", model)
        ## for lora hooking: wrap every Linear inside the MultiheadAttention blocks with a LoRAHook
        lora_modules = []
        for modules in model.modules():
            if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
                for module in modules.modules():
                    if isinstance(module, torch.nn.Linear):
                        lora_modules.append(LoRAHook(module))
        for i, lora in enumerate(lora_modules):
    else:
        if not params.use_averaged_model:
            if params.iter > 0:
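The diff references a `LoRAHook` class that is not defined in this file. Below is a minimal sketch of what such a wrapper could look like, assuming the standard LoRA formulation (a frozen `nn.Linear` whose output is augmented by a trainable low-rank update applied through a forward hook). The class name matches the diff, but the rank, scaling, and hook mechanics are assumptions, not the author's actual implementation:

```python
import torch
import torch.nn as nn


class LoRAHook:
    """Hypothetical sketch: add a trainable low-rank update to a frozen nn.Linear.

    Effective computation: y = W x + b + (alpha / r) * B (A x),
    where only A and B are trained.
    """

    def __init__(self, linear: nn.Linear, r: int = 8, alpha: int = 16):
        self.linear = linear
        self.scaling = alpha / r
        # Low-rank factors; these are the only trainable parameters.
        self.lora_A = nn.Parameter(torch.randn(r, linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, r))
        # Freeze the wrapped layer and attach the hook.
        linear.weight.requires_grad_(False)
        if linear.bias is not None:
            linear.bias.requires_grad_(False)
        self.handle = linear.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        x = inputs[0]
        # Returning a tensor from a forward hook replaces the layer's output.
        return output + (x @ self.lora_A.t() @ self.lora_B.t()) * self.scaling

    def parameters(self):
        return [self.lora_A, self.lora_B]

    def remove(self):
        self.handle.remove()
```

Under this reading, the loop over `model.modules()` in the diff collects one hook per attention-internal `Linear`, so that only the low-rank parameters need to be passed to the optimizer, e.g. `torch.optim.AdamW([p for h in lora_modules for p in h.parameters()], lr=1e-4)`.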