from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
from sentence_transformers.losses import TripletLoss, TripletDistanceMetric
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction
import json

from datasets import Dataset

model = SentenceTransformer("jinaai/jina-embeddings-v3",
|
|
trust_remote_code=True,
|
|
local_files_only=False,
|
|
model_kwargs={'default_task': 'text-matching'})
|
|
|
|
print("loading dataset")
|
|
# Load the dataset to fine-tune on
with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
    all_dataset = json.load(f)

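# Expected record layout of train_100.json, inferred from the field accesses
# below (illustrative only; the actual file is the source of truth):
# {
#     "question": "...",
#     "passage_positive": "...",
#     "passage_negative": ["...", ...],
#     "passage_negative_random": ["...", ...]
# }
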
anchors = []
positives = []
negatives_1 = []
negatives_2 = []
negatives_3 = []
negatives_4 = []
negatives_5 = []

for data in all_dataset:
    anchors.append(data["question"])
    positives.append(data["passage_positive"])
    all_negatives = data["passage_negative"] + data["passage_negative_random"]
    # Pad to five negatives by repeating the first one. This assumes every sample
    # has at least one negative; an empty list would raise IndexError below.
    if len(all_negatives) < 5:
        for _ in range(5 - len(all_negatives)):
            all_negatives.append(all_negatives[0])
    negatives_1.append(all_negatives[0])
    negatives_2.append(all_negatives[1])
    negatives_3.append(all_negatives[2])
    negatives_4.append(all_negatives[3])
    negatives_5.append(all_negatives[4])

# TripletLoss consumes (anchor, positive, negative) triples, so only the first
# negative is used here; negatives_2 through negatives_5 are collected but unused.
dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
    "negative": negatives_1,
})

dataset_split = dataset.train_test_split(test_size=0.05, seed=42)

train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

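# Optional sanity check (not in the original script): confirm the split sizes.
print(f"train rows: {len(train_dataset)}, eval rows: {len(eval_dataset)}")
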
loss = TripletLoss(
    model,
    distance_metric=TripletDistanceMetric.COSINE,
    triplet_margin=0.75,
)

dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_similarity_function=SimilarityFunction.COSINE,
)

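# Optional baseline (not in the original script): score the un-finetuned model once
# so the post-training cosine accuracy has a reference point.
print("Pre-training eval:", dev_evaluator(model))
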
training_args = SentenceTransformerTrainingArguments(
    output_dir="save_dir",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=2.5e-5,
    warmup_ratio=0.1,
    greater_is_better=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_cosine_accuracy",  # reported by the TripletEvaluator
    fp16=False,
    bf16=True,  # requires a GPU with bfloat16 support
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=10,
    logging_steps=50,
    logging_first_step=True,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # also reports eval_loss alongside the evaluator metrics
    loss=loss,
    evaluator=dev_evaluator,
)

# pretraining_encoding = model.encode(["The human torch was denied a bank loan."])
# print("Pre-training encoding:", pretraining_encoding)

# Begin fine-tuning
trainer.train()
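
# Optional follow-up (not in the original script; the path is an assumption):
# persist the fine-tuned weights and mirror the commented-out sanity check above.
model.save("save_dir/final")
posttraining_encoding = model.encode(["The human torch was denied a bank loan."])
print("Post-training encoding:", posttraining_encoding)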