from local

This commit is contained in:
dohe0342 2023-05-26 11:51:30 +09:00
parent 45dbc89145
commit 594e74d594
2 changed files with 16 additions and 0 deletions

View File

@@ -666,6 +666,22 @@ def main():
if '.pt' in params.model_name:
load_checkpoint(f"{params.exp_dir}/{params.model_name}", model)
elif 'lora' in params.model_name:
load_checkpoint(f"{params.exp_dir}/../d2v-base-T.pt", model)
## for lora hooking
lora_modules = []
for modules in model.modules():
if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
for module in modules.modules():
if isinstance(module, torch.nn.Linear):
lora_modules.append(LoRAHook(module))
for i, lora in enumerate(lora_modules):
lora_param = torch.load(f"{params.exp_dir}/lora_{params.iter}_{i}.pt")
lora.lora.load_state_dict(lora_param)
lora.lora.to(device)
logging.info("lora params load done")
else:
if not params.use_averaged_model:
if params.iter > 0: