embedding_model/train/qwen/merge_model.py
2025-11-20 06:48:29 +00:00

24 lines
858 B
Python

import json
import numpy as np
import os
from peft import PeftModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
def merge(base_model_path, peft_model_path, save_path, *, dtype="bfloat16", save_tokenizer=True):
    """Merge a PEFT adapter into its base model and save the standalone result.

    Args:
        base_model_path: Directory (or hub id) of the base model the adapter
            was trained on.
        peft_model_path: Directory containing the PEFT adapter checkpoint.
        save_path: Output directory for the merged model.
        dtype: Dtype the base weights are loaded and merged in; forwarded to
            ``from_pretrained`` as ``torch_dtype``.
        save_tokenizer: When true, also write the base model's tokenizer files
            into ``save_path`` so the merged directory is self-contained and
            can be loaded with ``AutoTokenizer.from_pretrained(save_path)``.
    """
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=dtype)
    peft_model = PeftModel.from_pretrained(base_model, peft_model_path)
    # merge_and_unload() folds the adapter deltas into the base weights and
    # returns the plain transformers model without the PEFT wrappers.
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(save_path)
    if save_tokenizer:
        # NOTE(review): assumes fine-tuning did not change the tokenizer (e.g. no
        # added tokens); if it did, load the tokenizer from peft_model_path instead.
        tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        tokenizer.save_pretrained(save_path)
def main():
    """Merge the adapter at output/v0-20251118-115015/checkpoint-3434 into the
    Qwen3-Embedding-0.6B base model.

    All paths are resolved relative to the directory containing this script,
    so the result does not depend on the current working directory.
    """
    # abspath(): on Python < 3.9 the entry script's __file__ can be a bare
    # filename, so dirname() returns "" and a "/../.." suffix would turn into
    # an absolute path rooted at "/".
    script_dir = os.path.dirname(os.path.abspath(__file__))
    run_dir = os.path.join(script_dir, "output", "v0-20251118-115015")
    checkpoint_name = "checkpoint-3434"

    base_model_path = os.path.join(
        script_dir, "..", "..", "data", "models", "Qwen3-Embedding-0.6B", "model"
    )
    peft_model_path = os.path.join(run_dir, checkpoint_name)
    save_path = os.path.join(run_dir, "merged_" + checkpoint_name)
    merge(base_model_path, peft_model_path, save_path)
if __name__ == "__main__":
    # Entry point: run the merge only when executed as a script, not on import.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())