# embedding_model/train/qwen/slerp_merge.py
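"""SLERP-merge LoRA adapter checkpoints into a single adapter.

Loads adapter tensors from safetensors files, interpolates matching tensors
along the spherical arc between them, and writes the result as a PEFT-style
checkpoint (adapter weights plus copied adapter configs).
"""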

import os
import shutil

import torch
from peft import PeftModel
from safetensors.torch import load_file, save_file
from transformers import AutoModelForCausalLM

# Run PEFT's BOFT ops on the CPU (avoids compiling its CUDA extension).
os.environ["PEFT_BOFT_FORCE_CPU"] = "1"

def slerp(t, v0, v1, eps=1e-8):
    """Spherical linear interpolation between tensors v0 and v1 at fraction t."""
    v0_norm = v0 / (v0.norm() + eps)
    v1_norm = v1 / (v1.norm() + eps)
    dot = (v0_norm * v1_norm).sum()
    dot = torch.clamp(dot, -1.0, 1.0)
    # Nearly (anti-)parallel directions make sin(theta_0) vanish, so fall
    # back to plain linear interpolation there.
    if dot.abs() > 0.9995:
        return (1 - t) * v0 + t * v1
    theta_0 = torch.acos(dot)
    sin_theta_0 = torch.sin(theta_0)
    theta_t = theta_0 * t
    sin_theta_t = torch.sin(theta_t)
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    return s0 * v0 + s1 * v1
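
# Minimal sanity check (illustration only, not part of the training pipeline):
# slerp must return the endpoints at t = 0 and t = 1, and the midpoint of two
# orthogonal unit vectors lies at 45 degrees on the unit arc.
def _slerp_sanity_check():
    v0 = torch.tensor([1.0, 0.0])
    v1 = torch.tensor([0.0, 1.0])
    assert torch.allclose(slerp(0.0, v0, v1), v0, atol=1e-6)
    assert torch.allclose(slerp(1.0, v0, v1), v1, atol=1e-6)
    # theta_0 = pi/2, so both midpoint components are sin(pi/4)/sin(pi/2) ~ 0.7071
    assert torch.allclose(slerp(0.5, v0, v1), torch.tensor([0.7071, 0.7071]), atol=1e-3)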

def load_lora(path):
    return load_file(path)

def merge_loras_slerp(lora_paths, weights=None):
    loras = [load_lora(p) for p in lora_paths]
    if weights is None:
        weights = [1 / len(loras)] * len(loras)
    merged = {}
    keys = loras[0].keys()
    for k in keys:
        tensors = [lora[k].float() for lora in loras]
        # Fold the adapters in one at a time; t is the relative weight of the
        # incoming tensor against everything merged so far, so the final
        # result respects the requested per-adapter weights.
        out = tensors[0]
        acc_weight = weights[0]
        for i in range(1, len(tensors)):
            t = weights[i] / (acc_weight + weights[i])
            out = slerp(t, out, tensors[i])
            acc_weight += weights[i]
        # Cast back to the adapter's original storage dtype (tensors[0] was
        # promoted to float32 above, so its dtype is not the one to restore).
        merged[k] = out.to(loras[0][k].dtype)
    return merged
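
# Illustration of the pairwise weighting (hypothetical scalar-like tensors;
# parallel inputs trigger the lerp fallback, so equal weights must reproduce
# the plain mean): t = weights[i] / (acc + weights[i]) hands the incoming
# tensor exactly its share of the total weight accumulated so far.
def _iterative_weighting_demo():
    tensors = [torch.tensor([1.0, 1.0]), torch.tensor([2.0, 2.0]), torch.tensor([3.0, 3.0])]
    weights = [1 / 3] * 3
    out, acc = tensors[0], weights[0]
    for i in range(1, len(tensors)):
        t = weights[i] / (acc + weights[i])
        out = slerp(t, out, tensors[i])
        acc += weights[i]
    assert torch.allclose(out, torch.tensor([2.0, 2.0]))  # the mean of the three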

def merge(base_model_path, peft_model_path, save_path):
    # Fold a single LoRA adapter into the base model weights and save the
    # resulting standalone model.
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
    ft_model = PeftModel.from_pretrained(base_model, peft_model_path)
    ft_model = ft_model.merge_and_unload()
    ft_model.save_pretrained(save_path)
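
# Usage sketch (assumption, not exercised below): once a base model path is
# available, the SLERP-merged adapter that main() writes can be folded into
# the base weights with the same helper, e.g.
#   merge(base_model_path, slerp_checkpoint_dir, merged_model_dir)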

def copy_selected_items(src_path, dst_path, items):
    os.makedirs(dst_path, exist_ok=True)
    for item in items:
        source_item = os.path.join(src_path, item)
        dest_item = os.path.join(dst_path, item)
        if not os.path.exists(source_item):
            print(f"{item} not found in the source path!")
            continue
        if os.path.isdir(source_item):
            shutil.copytree(source_item, dest_item, dirs_exist_ok=True)
        elif os.path.isfile(source_item):
            shutil.copy2(source_item, dest_item)

def main():
    file_path = os.path.dirname(__file__)
    # base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
    peft_model_path = [
        file_path + "/output/v23-20251214-111804/checkpoint-3632",
        file_path + "/output/v23-20251214-111804/checkpoint-3000",
    ]
    save_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
    merged_lora = merge_loras_slerp([p + "/adapter_model.safetensors" for p in peft_model_path])
    os.makedirs(save_path, exist_ok=True)
    save_file(merged_lora, save_path + "/adapter_model.safetensors")
    # merge(base_model_path, peft_model_path, save_path)
    items = ["adapter_config.json", "additional_config.json"]
    copy_selected_items(peft_model_path[0], save_path, items)


if __name__ == "__main__":
    main()