import json

from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.evaluation import RerankingEvaluator
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
print("start")
########### Load model ###########
print("loading model")
# 1. Load the model to finetune
# jina-embeddings-v3 uses task-specific LoRA adapters; 'retrieval.passage' makes the
# passage-retrieval adapter the default for encoding.
model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    local_files_only=False,
    model_kwargs={"default_task": "retrieval.passage"},
)
########### Load dataset ###########
print("loading dataset")
# 3. Load a dataset to finetune on
with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
    all_dataset = json.load(f)
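# Each record in train_100.json is assumed to look roughly like this
# (field names inferred from how they are accessed below):
# {
#     "question": "...",
#     "passage_positive": "...",
#     "passage_negative": ["...", ...],
#     "passage_negative_random": ["...", ...]
# }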
# MultipleNegativesRankingLoss expects InputExample(texts=[anchor, positive]).
# The explicit negatives are ignored during training; in-batch negatives are used instead.
all_examples = []
for data in all_dataset:
    all_examples.append(InputExample(texts=[data["question"], data["passage_positive"]]))
# Split the dataset into train and evaluation
train_examples, eval_examples = train_test_split(all_examples, test_size=0.05, random_state=42)
print(f"Training with {len(train_examples)} examples")
print(f"Evaluating with {len(eval_examples)} examples")
########### Load loss function ###########
print("loading loss function")
# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)
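# With this loss, every other positive passage in a batch serves as an in-batch negative,
# so with batch_size=4 each question is contrasted against its own passage and the 3 others.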
########### Load evaluator ###########
print("loading evaluator")
# 6. (Optional) Create an evaluator
# Build reranking samples from the held-out eval_examples, reusing the explicit negatives.
eval_dataset_evaluator = []
for data in all_dataset:  # scan all_dataset again to recover the matching negatives
    example_query = data["question"]
    example_positive = data["passage_positive"]
    # Check whether this (query, positive) pair landed in the eval split
    is_in_eval = False
    for eval_ex in eval_examples:
        if eval_ex.texts[0] == example_query and eval_ex.texts[1] == example_positive:
            is_in_eval = True
            break
    if is_in_eval:
        all_negatives = data["passage_negative"] + data["passage_negative_random"]
        if not all_negatives:
            continue  # RerankingEvaluator needs at least one negative per sample
        if len(all_negatives) < 5:
            for i in range(5 - len(all_negatives)):
                all_negatives.append(all_negatives[0])  # Pad by repeating the first negative
        eval_dataset_evaluator.append({
            "query": example_query,
            "positive": [example_positive],
            "negative": all_negatives[:5],  # Keep at most 5 explicit negatives for evaluation
        })
dev_evaluator = RerankingEvaluator(
    name="jina_v3",
    samples=eval_dataset_evaluator,
)
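# RerankingEvaluator reranks each query's positive against its negatives with the current
# model and reports mean average precision (MAP) and MRR@k (k=10 by default).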
# dev_evaluator(model)  # Uncomment to benchmark the base model before finetuning
########### Train the model ###########
print("starting training with model.fit()")
# Create a DataLoader for the training examples
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=4)
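# model.fit() swaps in its own smart-batching collate_fn on this DataLoader,
# so passing a plain list of InputExample objects works here.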
# 7. Train the model using model.fit()
model.fit(
    train_objectives=[(train_dataloader, loss)],
    evaluator=dev_evaluator,
    epochs=1,
    evaluation_steps=5,
    warmup_steps=int(len(train_dataloader) * 0.1),  # 10% warmup
    output_path="models/jina_v3",
    save_best_model=True,
    show_progress_bar=True,
    use_amp=True,  # mixed-precision training (replaces fp16=True from the Trainer API)
)
########### Save the trained model ###########
# model.fit() already saves the best model to output_path, but you can save again
print("saving final model")
model.save_pretrained("models/jina_v3_final")
print("model saved")
print("end")