"""Merge a fine-tuned PEFT adapter checkpoint into its base model and save the result."""
import json
import os
from pathlib import Path

import numpy as np
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
|
|
def merge(base_model_path, peft_model_path, save_path, *, save_tokenizer=True):
    """Merge a PEFT adapter into its base model and save the merged weights.

    Args:
        base_model_path: Location of the base model, loaded with
            ``AutoModelForCausalLM`` in bfloat16.
        peft_model_path: Location of the PEFT adapter checkpoint trained on
            top of ``base_model_path``.
        save_path: Directory the merged model is written to.
        save_tokenizer: When true (the default), also load the tokenizer from
            ``base_model_path`` and save it into ``save_path`` so the merged
            directory is self-contained. Set to False to write weights only.
    """
    # torch.bfloat16 (rather than the string "bfloat16") is accepted by all
    # transformers versions; older releases only allow a dtype or "auto".
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.bfloat16
    )
    peft_model = PeftModel.from_pretrained(base_model, peft_model_path)

    # Fold the adapter weights into the base layers and drop the PEFT
    # wrappers, leaving a plain transformers model to save.
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(save_path)

    if save_tokenizer:
        # The adapter does not change the vocabulary, so the base tokenizer
        # is reused as-is for the merged checkpoint.
        tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        tokenizer.save_pretrained(save_path)
|
|
|
|
def main():
    """Merge the fine-tuning checkpoint configured below into its base model.

    All locations are resolved against the absolute directory containing this
    script, so the result does not depend on how ``__file__`` was spelled or
    on the current working directory.
    """
    # os.path.dirname() of a bare filename is "", which would turn
    # "/../../data/..." into a path at the filesystem root; anchor on the
    # absolute script directory instead.
    script_dir = Path(__file__).absolute().parent

    base_model_path = script_dir / "../../data/models/Qwen3-Embedding-0.6B/model"

    # Training run and checkpoint to merge; the merged output is written next
    # to the checkpoint inside the same run directory.
    run_dir = script_dir / "output" / "v1-20251122-184545"
    checkpoint_name = "checkpoint-3434"

    peft_model_path = run_dir / checkpoint_name
    save_path = run_dir / f"merged_{checkpoint_name}"

    # merge() was originally called with plain strings; keep that contract.
    merge(str(base_model_path), str(peft_model_path), str(save_path))
|
|
|
|
|
|
# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()