# embedding_model/train/qwen/merge_model.py
# Snapshot metadata: 2025-12-07 06:56:10 +00:00 — 47 lines, 1.7 KiB, Python

import json
import numpy as np
import os
from peft import PeftModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import shutil
# Force PEFT's BOFT kernels onto the CPU; must be set before any PEFT
# model is constructed, which is why it sits at import time.
# NOTE(review): json/np/torch/AutoTokenizer/BitsAndBytesConfig appear unused
# in this visible chunk — confirm before pruning.
os.environ["PEFT_BOFT_FORCE_CPU"] = "1"
def merge(base_model_path, peft_model_path, save_path):
    """Bake a PEFT adapter into its base causal LM and save the result.

    Args:
        base_model_path: directory of the base model checkpoint.
        peft_model_path: directory of the trained PEFT adapter.
        save_path: destination directory for the merged weights.
    """
    # Load base weights in bfloat16, stack the adapter on top, then fold
    # the adapter deltas into the base weights and persist the plain model.
    model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="bfloat16")
    model = PeftModel.from_pretrained(model, peft_model_path)
    merged = model.merge_and_unload()
    merged.save_pretrained(save_path)
def copy_selected_items(src_path, dst_path, items):
    """Copy the named files/directories from src_path into dst_path.

    Missing entries are reported and skipped; existing destinations are
    overwritten (directories are merged via dirs_exist_ok).

    Args:
        src_path: source directory to copy from.
        dst_path: destination directory (created if absent).
        items: iterable of entry names relative to src_path.
    """
    os.makedirs(dst_path, exist_ok=True)
    for name in items:
        candidate = os.path.join(src_path, name)
        target = os.path.join(dst_path, name)
        if not os.path.exists(candidate):
            # Entry absent from the source tree — warn and move on.
            print(f"{name} در مسیر مبدا پیدا نشد!")
        elif os.path.isdir(candidate):
            shutil.copytree(candidate, target, dirs_exist_ok=True)
        elif os.path.isfile(candidate):
            shutil.copy2(candidate, target)
def main():
    """Merge the v17 LoRA checkpoint into the Qwen3 embedding base model."""
    here = os.path.dirname(__file__)
    base_model_path = here + "/../../data/models/Qwen3-Embedding-0.6B/model"
    peft_model_path = here + "/output/v17-20251202-223944/checkpoint-387"
    save_path = here + "/output/v17-20251202-223944/merged_checkpoint-387"

    merge(base_model_path, peft_model_path, save_path)

    # Presumably save_pretrained() on the bare LM does not emit the
    # Sentence-Transformers pooling config or tokenizer files, so they are
    # copied over from the base checkpoint — TODO confirm.
    extras = [
        "1_Pooling",
        "config_sentence_transformers.json",
        "merges.txt",
        "modules.json",
        "README.md",
        "tokenizer_config.json",
        "tokenizer.json",
        "vocab.json",
    ]
    copy_selected_items(base_model_path, save_path, extras)


if __name__ == "__main__":
    main()