add slerp in qwen
This commit is contained in:
parent
793508dbd0
commit
4c3b90457d
@ -33,8 +33,8 @@ def main():
|
||||
file_path = os.path.dirname(__file__)
|
||||
|
||||
base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
|
||||
peft_model_path = file_path + "/output/v23-20251214-111804/checkpoint-3632"
|
||||
save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-3632"
|
||||
peft_model_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
|
||||
save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-slerp"
|
||||
merge(base_model_path, peft_model_path, save_path)
|
||||
|
||||
items = ["1_Pooling", "config_sentence_transformers.json", "merges.txt", "modules.json", "README.md", "tokenizer_config.json", "tokenizer.json",
|
||||
|
||||
114
train/qwen/slerp_merge.py
Normal file
114
train/qwen/slerp_merge.py
Normal file
@ -0,0 +1,114 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
from peft import PeftModel
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
import shutil
|
||||
from safetensors.torch import load_file
|
||||
from safetensors.torch import save_file
|
||||
|
||||
os.environ["PEFT_BOFT_FORCE_CPU"] = "1"
|
||||
|
||||
|
||||
|
||||
def slerp(t, v0, v1, eps=1e-8):
|
||||
v0_norm = v0 / (v0.norm() + eps)
|
||||
v1_norm = v1 / (v1.norm() + eps)
|
||||
|
||||
dot = (v0_norm * v1_norm).sum()
|
||||
dot = torch.clamp(dot, -1.0, 1.0)
|
||||
|
||||
if dot > 0.9995:
|
||||
# fallback to linear interpolation
|
||||
return (1 - t) * v0 + t * v1
|
||||
|
||||
theta_0 = torch.acos(dot)
|
||||
sin_theta_0 = torch.sin(theta_0)
|
||||
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = torch.sin(theta_t)
|
||||
|
||||
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
|
||||
return s0 * v0 + s1 * v1
|
||||
|
||||
|
||||
def load_lora(path):
|
||||
return load_file(path)
|
||||
|
||||
def merge_loras_slerp(lora_paths, weights=None):
|
||||
loras = [load_lora(p) for p in lora_paths]
|
||||
|
||||
if weights is None:
|
||||
weights = [1 / len(loras)] * len(loras)
|
||||
|
||||
merged = {}
|
||||
|
||||
keys = loras[0].keys()
|
||||
|
||||
for k in keys:
|
||||
tensors = [l[k].float() for l in loras]
|
||||
|
||||
# iterative slerp
|
||||
out = tensors[0]
|
||||
acc_weight = weights[0]
|
||||
|
||||
for i in range(1, len(tensors)):
|
||||
t = weights[i] / (acc_weight + weights[i])
|
||||
out = slerp(t, out, tensors[i])
|
||||
acc_weight += weights[i]
|
||||
|
||||
merged[k] = out.to(tensors[0].dtype)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def merge(base_model_path, peft_model_path, save_path):
|
||||
base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="bfloat16")
|
||||
ft_model = PeftModel.from_pretrained(base_model, peft_model_path)
|
||||
ft_model = ft_model.merge_and_unload()
|
||||
ft_model.save_pretrained(save_path)
|
||||
|
||||
def copy_selected_items(src_path, dst_path, items):
|
||||
os.makedirs(dst_path, exist_ok=True)
|
||||
|
||||
for item in items:
|
||||
source_item = os.path.join(src_path, item)
|
||||
dest_item = os.path.join(dst_path, item)
|
||||
|
||||
if not os.path.exists(source_item):
|
||||
print(f"⚠ {item} در مسیر مبدا پیدا نشد!")
|
||||
continue
|
||||
|
||||
if os.path.isdir(source_item):
|
||||
shutil.copytree(source_item, dest_item, dirs_exist_ok=True)
|
||||
elif os.path.isfile(source_item):
|
||||
shutil.copy2(source_item, dest_item)
|
||||
|
||||
def main():
|
||||
file_path = os.path.dirname(__file__)
|
||||
|
||||
# base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
|
||||
peft_model_path = [
|
||||
file_path + "/output/v23-20251214-111804/checkpoint-3632",
|
||||
file_path + "/output/v23-20251214-111804/checkpoint-3000",
|
||||
file_path + "/output/v23-20251214-111804/checkpoint-2000",
|
||||
]
|
||||
save_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
|
||||
|
||||
|
||||
merged_lora = merge_loras_slerp([peft_model_path[i] + "/adapter_model.safetensors" for i in range(len(peft_model_path))])
|
||||
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
save_file(merged_lora, save_path + "/adapter_model.safetensors")
|
||||
# merge(base_model_path, peft_model_path, save_path)
|
||||
|
||||
items = ["adapter_config.json", "additional_config.json"]
|
||||
copy_selected_items(peft_model_path[0], save_path, items)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user