add slerp in qwen

This commit is contained in:
a.hediehloo 2025-12-21 16:51:16 +00:00
parent 793508dbd0
commit 4c3b90457d
2 changed files with 116 additions and 2 deletions

View File

@ -33,8 +33,8 @@ def main():
file_path = os.path.dirname(__file__)
base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
peft_model_path = file_path + "/output/v23-20251214-111804/checkpoint-3632"
save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-3632"
peft_model_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-slerp"
merge(base_model_path, peft_model_path, save_path)
items = ["1_Pooling", "config_sentence_transformers.json", "merges.txt", "modules.json", "README.md", "tokenizer_config.json", "tokenizer.json",

114
train/qwen/slerp_merge.py Normal file
View File

@ -0,0 +1,114 @@
import json
import numpy as np
import os
from peft import PeftModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import shutil
from safetensors.torch import load_file
from safetensors.torch import save_file
os.environ["PEFT_BOFT_FORCE_CPU"] = "1"
def slerp(t, v0, v1, eps=1e-8):
v0_norm = v0 / (v0.norm() + eps)
v1_norm = v1 / (v1.norm() + eps)
dot = (v0_norm * v1_norm).sum()
dot = torch.clamp(dot, -1.0, 1.0)
if dot > 0.9995:
# fallback to linear interpolation
return (1 - t) * v0 + t * v1
theta_0 = torch.acos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = torch.sin(theta_t)
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
return s0 * v0 + s1 * v1
def load_lora(path):
return load_file(path)
def merge_loras_slerp(lora_paths, weights=None):
loras = [load_lora(p) for p in lora_paths]
if weights is None:
weights = [1 / len(loras)] * len(loras)
merged = {}
keys = loras[0].keys()
for k in keys:
tensors = [l[k].float() for l in loras]
# iterative slerp
out = tensors[0]
acc_weight = weights[0]
for i in range(1, len(tensors)):
t = weights[i] / (acc_weight + weights[i])
out = slerp(t, out, tensors[i])
acc_weight += weights[i]
merged[k] = out.to(tensors[0].dtype)
return merged
def merge(base_model_path, peft_model_path, save_path):
base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="bfloat16")
ft_model = PeftModel.from_pretrained(base_model, peft_model_path)
ft_model = ft_model.merge_and_unload()
ft_model.save_pretrained(save_path)
def copy_selected_items(src_path, dst_path, items):
os.makedirs(dst_path, exist_ok=True)
for item in items:
source_item = os.path.join(src_path, item)
dest_item = os.path.join(dst_path, item)
if not os.path.exists(source_item):
print(f"{item} در مسیر مبدا پیدا نشد!")
continue
if os.path.isdir(source_item):
shutil.copytree(source_item, dest_item, dirs_exist_ok=True)
elif os.path.isfile(source_item):
shutil.copy2(source_item, dest_item)
def main():
file_path = os.path.dirname(__file__)
# base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
peft_model_path = [
file_path + "/output/v23-20251214-111804/checkpoint-3632",
file_path + "/output/v23-20251214-111804/checkpoint-3000",
file_path + "/output/v23-20251214-111804/checkpoint-2000",
]
save_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
merged_lora = merge_loras_slerp([peft_model_path[i] + "/adapter_model.safetensors" for i in range(len(peft_model_path))])
os.makedirs(save_path, exist_ok=True)
save_file(merged_lora, save_path + "/adapter_model.safetensors")
# merge(base_model_path, peft_model_path, save_path)
items = ["adapter_config.json", "additional_config.json"]
copy_selected_items(peft_model_path[0], save_path, items)
if __name__ == "__main__":
main()