# embedding_model/train/jina/jina_train.py
# (snapshot 2025-11-10 15:32:25 +00:00 — 145 lines, 4.5 KiB, Python)
from datasets import Dataset, load_dataset, VerificationMode
import json
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import RerankingEvaluator
print("start")
########### Load model ###########
print("loading model")
# 1. Load the base model to finetune; `default_task` selects the LoRA adapter
# used by jina-embeddings-v3 for passage-retrieval encoding.
jina_model_kwargs = {"default_task": "retrieval.passage"}
model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    local_files_only=False,
    model_kwargs=jina_model_kwargs,
)
########### Load dataset ###########
print("loading dataset")
# 3. Load a dataset to finetune on.
# train_100.json is a JSON list of records, each with keys:
#   "question" (str), "passage_positive" (list of str),
#   "passage_negative" (list of str), "passage_negative_random" (list of str).
NUM_NEGATIVES = 5


def _pad_negatives(negatives, k=NUM_NEGATIVES):
    """Return exactly *k* negative passages from *negatives*.

    Pools longer than *k* are truncated; shorter pools are padded by
    repeating the first negative (same policy as the original loop).
    An empty pool now raises a clear ValueError instead of the obscure
    IndexError (`negatives[0]` on an empty list) the old code produced.
    """
    if not negatives:
        raise ValueError("sample has no negative passages")
    padded = list(negatives)
    while len(padded) < k:
        padded.append(padded[0])
    return padded[:k]


with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
    all_dataset = json.load(f)

# Build one column per field: "anchor", "positive", "negative_1".."negative_5".
# A keyed dict of columns replaces the five hand-maintained parallel lists.
columns = {"anchor": [], "positive": []}
for i in range(1, NUM_NEGATIVES + 1):
    columns[f"negative_{i}"] = []
for data in all_dataset:
    columns["anchor"].append(data["question"])
    columns["positive"].append(data["passage_positive"][0])
    negatives = _pad_negatives(
        data["passage_negative"] + data["passage_negative_random"]
    )
    for i, negative in enumerate(negatives, start=1):
        columns[f"negative_{i}"].append(negative)

dataset = Dataset.from_dict(columns)
# Hold out 5% for evaluation; the fixed seed keeps the split reproducible.
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
########### Load loss function ###########
print("loading loss function")
# 4. Loss: in-batch negatives ranking over (anchor, positive, negative_1..5).
loss = MultipleNegativesRankingLoss(model=model)
########### Load training arguments ###########
print("loading training arguments")
# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/jina_v3",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # FIX: was BatchSamplers.BATCH_SAMPLER, which contradicted the comment
    # next to it. MultipleNegativesRankingLoss uses every other in-batch
    # sample as a negative, so duplicate samples in one batch silently
    # corrupt the loss; NO_DUPLICATES is the sampler the sentence-transformers
    # docs recommend for this loss.
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    # eval_strategy="steps",
    # eval_steps=5,
    save_strategy="steps",
    save_steps=10,
    save_total_limit=2,
    logging_steps=5,
    run_name="jina_v3",  # Will be used in W&B if `wandb` is installed
)
########### Load evaluator ###########
print("loading evaluator")
# 6. (Optional) Build a reranking evaluator over the held-out split: each
# sample pairs the query with its positive and its five mined negatives.
reranking_samples = []
for row in eval_dataset:
    reranking_samples.append(
        {
            "query": row["anchor"],
            "positive": [row["positive"]],
            "negative": [row[f"negative_{i}"] for i in range(1, 6)],
        }
    )
dev_evaluator = RerankingEvaluator(
    samples=reranking_samples,
    name="jina_v3",
)
# dev_evaluator(model)
########### Load trainer ###########
print("loading trainer")
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=loss,
    train_dataset=train_dataset,
    # eval_dataset=eval_dataset,
    # evaluator=dev_evaluator,
)
trainer.train()
########### Load test evaluator ###########
# (Optional) Evaluate the trained model on the test set, e.g.:
# test_evaluator = TripletEvaluator(
#     anchors=test_dataset["anchor"],
#     positives=test_dataset["positive"],
#     negatives=test_dataset["negative"],
#     name="all-nli-test",
# )
# test_evaluator(model)
########### Save the trained model ###########
print("saving model")
# 8. Save the trained model
model.save_pretrained("models/jina_v3")
print("model saved")
print("end")