114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
import json
|
|
import numpy as np
|
|
import os
|
|
from peft import PeftModel
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
import shutil
|
|
from safetensors.torch import load_file
|
|
from safetensors.torch import save_file
|
|
|
|
# NOTE(review): presumably makes PEFT's BOFT code stay on its CPU path
# (e.g. skip building its CUDA extension). This script only merges LoRA
# tensors, so it looks like a defensive leftover — confirm against PEFT.
os.environ.update(PEFT_BOFT_FORCE_CPU="1")
|
|
|
|
|
|
|
|
def slerp(t, v0, v1, eps=1e-8, dot_threshold=0.9995):
    """Spherically interpolate between two tensors.

    The angle between ``v0`` and ``v1`` is measured on their normalized,
    flattened directions, but the interpolation coefficients are applied to
    the original (unnormalized) tensors, so magnitudes are blended as well.

    Args:
        t: Interpolation factor; 0 yields ``v0`` and 1 yields ``v1``.
        v0: Start tensor.
        v1: End tensor (same shape as ``v0``).
        eps: Small constant guarding the normalizations against zero norms.
        dot_threshold: When ``|cos(angle)|`` exceeds this value the inputs are
            treated as (anti-)parallel and plain linear interpolation is used,
            because ``sin(angle)`` is ~0 there and the slerp coefficients
            become numerically meaningless.

    Returns:
        Tensor with the same shape as the inputs.
    """
    v0_unit = v0 / (v0.norm() + eps)
    v1_unit = v1 / (v1.norm() + eps)

    # Cosine of the angle between the flattened tensors; clamp guards acos
    # against rounding slightly outside [-1, 1].
    dot = torch.clamp((v0_unit * v1_unit).sum(), -1.0, 1.0)

    # Near-parallel AND near-antiparallel inputs both give sin(theta_0) ~ 0,
    # which would divide by ~0 below; fall back to linear interpolation.
    if dot.abs() > dot_threshold:
        return (1 - t) * v0 + t * v1

    theta_0 = torch.acos(dot)
    sin_theta_0 = torch.sin(theta_0)
    theta_t = theta_0 * t

    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = torch.sin(theta_t) / sin_theta_0
    return s0 * v0 + s1 * v1
|
|
|
|
|
|
def load_lora(path):
    """Read all tensors of an adapter from a ``.safetensors`` file.

    Args:
        path: Filesystem path of the safetensors file.

    Returns:
        Dict mapping tensor names to tensors, exactly as returned by
        ``safetensors.torch.load_file``.
    """
    state_dict = load_file(path)
    return state_dict
|
|
|
|
def merge_loras_slerp(lora_paths, weights=None):
    """Merge several LoRA adapter files by iterated spherical interpolation.

    Tensors are merged key by key. Starting from the first adapter, each
    further adapter ``i`` is slerped into the running result with factor
    ``weights[i] / (weights[0] + ... + weights[i])``, so every adapter pulls
    the running merge in proportion to its weight.

    NOTE(review): every tensor is interpolated independently; if the keys hold
    separate low-rank factors (presumably lora_A / lora_B), the product of the
    merged factors is not the slerp of the per-adapter deltas — confirm this
    is the intended merge.

    Args:
        lora_paths: Paths to ``.safetensors`` adapter files. The keys of the
            first file define which tensors are merged.
        weights: Optional per-adapter weights, one per path. Defaults to
            uniform weights.

    Returns:
        Dict mapping each key to its merged tensor, cast back to the dtype the
        tensor had in the first adapter.

    Raises:
        ValueError: if ``lora_paths`` is empty or ``weights`` has a different
            length than ``lora_paths``.
        KeyError: if a later adapter lacks a key present in the first one.
    """
    if not lora_paths:
        raise ValueError("lora_paths must contain at least one adapter path")

    loras = [load_lora(p) for p in lora_paths]

    if weights is None:
        weights = [1 / len(loras)] * len(loras)
    elif len(weights) != len(loras):
        raise ValueError(f"got {len(weights)} weights for {len(loras)} adapters")

    merged = {}
    for k, first in loras[0].items():
        # Record the stored dtype (e.g. bfloat16) BEFORE upcasting: reading it
        # after `.float()` would always yield float32 and silently change the
        # dtype of the saved adapter.
        orig_dtype = first.dtype
        tensors = [first.float()] + [lora[k].float() for lora in loras[1:]]

        # Iterated slerp: fold each remaining adapter into the running result.
        out = tensors[0]
        acc_weight = weights[0]
        for w, tensor in zip(weights[1:], tensors[1:]):
            t = w / (acc_weight + w)
            out = slerp(t, out, tensor)
            acc_weight += w

        merged[k] = out.to(orig_dtype)

    return merged
|
|
|
|
|
|
def merge(base_model_path, peft_model_path, save_path):
    """Fold a PEFT adapter into its base causal-LM weights and save the result.

    Args:
        base_model_path: Path or hub id of the base model, loaded in bfloat16.
        peft_model_path: Path of a single adapter checkpoint directory.
        save_path: Output directory for the merged model.
    """
    base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="bfloat16")
    merged_model = PeftModel.from_pretrained(base, peft_model_path).merge_and_unload()
    merged_model.save_pretrained(save_path)
|
|
|
|
def copy_selected_items(src_path, dst_path, items):
    """Copy the named files/directories from ``src_path`` into ``dst_path``.

    ``dst_path`` is created if needed. Directories are copied recursively and
    merged into an existing destination directory; regular files are copied
    with ``shutil.copy2``. Names missing from ``src_path`` produce a printed
    warning and are skipped. Paths that exist but are neither a directory nor
    a regular file are ignored.

    Args:
        src_path: Source directory.
        dst_path: Destination directory.
        items: Names, relative to ``src_path``, to copy.
    """
    os.makedirs(dst_path, exist_ok=True)

    for name in items:
        src = os.path.join(src_path, name)
        dst = os.path.join(dst_path, name)

        if os.path.isdir(src):
            shutil.copytree(src, dst, dirs_exist_ok=True)
        elif os.path.isfile(src):
            shutil.copy2(src, dst)
        elif not os.path.exists(src):
            # Warning (Persian): "<name> was not found in the source path!"
            print(f"⚠ {name} در مسیر مبدا پیدا نشد!")
|
|
|
|
def main():
    """Slerp-merge two LoRA checkpoints of one training run into a new directory.

    Reads ``adapter_model.safetensors`` from each checkpoint directory, merges
    them with ``merge_loras_slerp`` (uniform weights), writes the result to
    ``save_path`` and copies the config files listed in ``items`` from the
    first checkpoint next to it.
    """
    # abspath: before Python 3.9 a script started as `python merge.py` had a
    # relative __file__, so dirname() could be "" and string-concatenated
    # paths like "" + "/output/..." resolved against the filesystem root.
    file_path = os.path.dirname(os.path.abspath(__file__))
    run_dir = os.path.join(file_path, "output", "v23-20251214-111804")

    # base_model_path = os.path.join(file_path, "..", "..", "data", "models", "Qwen3-Embedding-0.6B", "model")
    peft_model_path = [
        os.path.join(run_dir, "checkpoint-3632"),
        os.path.join(run_dir, "checkpoint-3000"),
    ]
    save_path = os.path.join(run_dir, "slerp-checkpoint")

    merged_lora = merge_loras_slerp(
        [os.path.join(p, "adapter_model.safetensors") for p in peft_model_path]
    )

    os.makedirs(save_path, exist_ok=True)
    # Same {"format": "pt"} header metadata PEFT writes for its own adapters.
    save_file(
        merged_lora,
        os.path.join(save_path, "adapter_model.safetensors"),
        metadata={"format": "pt"},
    )
    # To bake a single adapter into the base model instead (merge() expects one path):
    # merge(base_model_path, peft_model_path[0], save_path)

    items = ["adapter_config.json", "additional_config.json"]
    copy_selected_items(peft_model_path[0], save_path, items)
|
|
|
|
|
|
if __name__ == "__main__":  # run the merge only when executed as a script
    main()
|