add slerp in qwen
parent 793508dbd0
commit 4c3b90457d
@@ -33,8 +33,8 @@ def main():
     file_path = os.path.dirname(__file__)
 
     base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
-    peft_model_path = file_path + "/output/v23-20251214-111804/checkpoint-3632"
-    save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-3632"
+    peft_model_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
+    save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-slerp"
     merge(base_model_path, peft_model_path, save_path)
 
     items = ["1_Pooling", "config_sentence_transformers.json", "merges.txt", "modules.json", "README.md", "tokenizer_config.json", "tokenizer.json",
train/qwen/slerp_merge.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import json
import numpy as np
import os
from peft import PeftModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import shutil
from safetensors.torch import load_file
from safetensors.torch import save_file

os.environ["PEFT_BOFT_FORCE_CPU"] = "1"

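# Spherical linear interpolation (slerp) between two tensors treated as flat
# vectors: the angle is measured between their normalized directions and the
# result interpolates along that arc; nearly parallel inputs fall back to
# plain linear interpolation to avoid dividing by a vanishing sin(theta).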
def slerp(t, v0, v1, eps=1e-8):
    v0_norm = v0 / (v0.norm() + eps)
    v1_norm = v1 / (v1.norm() + eps)

    dot = (v0_norm * v1_norm).sum()
    dot = torch.clamp(dot, -1.0, 1.0)

    if dot > 0.9995:
        # fallback to linear interpolation
        return (1 - t) * v0 + t * v1

    theta_0 = torch.acos(dot)
    sin_theta_0 = torch.sin(theta_0)

    theta_t = theta_0 * t
    sin_theta_t = torch.sin(theta_t)

    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0

    return s0 * v0 + s1 * v1

def load_lora(path):
    return load_file(path)

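# Combine several LoRA adapter state dicts into one: for every tensor key the
# adapters are folded together pairwise with slerp, each new adapter weighted
# against the weight accumulated so far (equal weights by default).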
def merge_loras_slerp(lora_paths, weights=None):
    loras = [load_lora(p) for p in lora_paths]

    if weights is None:
        weights = [1 / len(loras)] * len(loras)

    merged = {}

    keys = loras[0].keys()

    for k in keys:
        tensors = [l[k].float() for l in loras]

        # iterative slerp
        out = tensors[0]
        acc_weight = weights[0]

        for i in range(1, len(tensors)):
            t = weights[i] / (acc_weight + weights[i])
            out = slerp(t, out, tensors[i])
            acc_weight += weights[i]

        # cast back to the adapter's original dtype (the tensors were upcast to float above)
        merged[k] = out.to(loras[0][k].dtype)

    return merged

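# Merge a PEFT adapter into the base model's weights and save the resulting
# standalone checkpoint.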
def merge(base_model_path, peft_model_path, save_path):
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="bfloat16")
    ft_model = PeftModel.from_pretrained(base_model, peft_model_path)
    ft_model = ft_model.merge_and_unload()
    ft_model.save_pretrained(save_path)

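# Copy the listed files and directories from src_path to dst_path, skipping
# (and reporting) any items that are missing from the source.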
def copy_selected_items(src_path, dst_path, items):
    os.makedirs(dst_path, exist_ok=True)

    for item in items:
        source_item = os.path.join(src_path, item)
        dest_item = os.path.join(dst_path, item)

        if not os.path.exists(source_item):
            print(f"⚠ {item} not found in the source path!")
            continue

        if os.path.isdir(source_item):
            shutil.copytree(source_item, dest_item, dirs_exist_ok=True)
        elif os.path.isfile(source_item):
            shutil.copy2(source_item, dest_item)

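# Build a slerp-merged adapter from three training checkpoints, save it as a
# new adapter directory, and copy the adapter config files next to it.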
def main():
    file_path = os.path.dirname(__file__)

    # base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
    peft_model_path = [
        file_path + "/output/v23-20251214-111804/checkpoint-3632",
        file_path + "/output/v23-20251214-111804/checkpoint-3000",
        file_path + "/output/v23-20251214-111804/checkpoint-2000",
    ]
    save_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"

    merged_lora = merge_loras_slerp([p + "/adapter_model.safetensors" for p in peft_model_path])

    os.makedirs(save_path, exist_ok=True)

    save_file(merged_lora, save_path + "/adapter_model.safetensors")
    # merge(base_model_path, peft_model_path, save_path)

    items = ["adapter_config.json", "additional_config.json"]
    copy_selected_items(peft_model_path[0], save_path, items)

if __name__ == "__main__":
    main()
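
Not part of the commit: a minimal sanity-check sketch for the slerp routine above, assuming it is run from a script placed next to the new slerp_merge.py module. With two orthogonal unit vectors, t=0 and t=1 reproduce the endpoints and t=0.5 lands halfway along the arc between them.

# Hypothetical sanity check, not part of the commit.
import torch

from slerp_merge import slerp  # assumes this script sits next to slerp_merge.py

v0 = torch.tensor([1.0, 0.0])
v1 = torch.tensor([0.0, 1.0])

print(slerp(0.0, v0, v1))  # ~[1., 0.]        -> reproduces v0
print(slerp(1.0, v0, v1))  # ~[0., 1.]        -> reproduces v1
print(slerp(0.5, v0, v1))  # ~[0.7071, 0.7071] -> halfway along the arc, unit norm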