From 467c21ce7e12faf7719b42df38721a0ef4dafabc Mon Sep 17 00:00:00 2001
From: saeedfirouzi
Date: Mon, 10 Nov 2025 15:32:25 +0000
Subject: [PATCH] train jina

---
 data_preprocess/preprocess_v1.py             |  56 +++---
 requirements.txt                             |   6 +-
 research_notebook/data_preprocess/bge.ipynb  | 173 +++++++++++++++
 research_notebook/data_preprocess/test.ipynb |  12 ++
 train/jina/jina_train.py                     |  42 +++--
 train/jina/test.py                           | 101 +++++++++++
 train/jina/test_2.py                         | 109 ++++++++++++
 train/jina/test_3.py                         | 141 +++++++++++++++
 8 files changed, 597 insertions(+), 43 deletions(-)
 create mode 100644 research_notebook/data_preprocess/bge.ipynb
 create mode 100644 train/jina/test.py
 create mode 100644 train/jina/test_2.py
 create mode 100644 train/jina/test_3.py

diff --git a/data_preprocess/preprocess_v1.py b/data_preprocess/preprocess_v1.py
index 67a0f29..c3bfbe4 100644
--- a/data_preprocess/preprocess_v1.py
+++ b/data_preprocess/preprocess_v1.py
@@ -150,40 +150,40 @@ def save_dataset(dataset, output_path):
 
 
 def main(output_path):
-    # #load synthetic dataset
-    # print("--------------------------------")
-    # print("loading synthetic dataset")
-    # synthetic_train_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/train.jsonl"
-    # synthetic_corpus_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/corpus.jsonl"
-    # synthetic_queries_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/queries.jsonl"
+    #load synthetic dataset
+    print("--------------------------------")
+    print("loading synthetic dataset")
+    synthetic_train_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/train.jsonl"
+    synthetic_corpus_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/corpus.jsonl"
+    synthetic_queries_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/queries.jsonl"
 
-    # synthetic_dataset = load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path)
-    # print(f"synthetic dataset loaded : {len(synthetic_dataset)} samples")
-    # print("--------------------------------")
+    synthetic_dataset = load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path)
+    print(f"synthetic dataset loaded : {len(synthetic_dataset)} samples")
+    print("--------------------------------")
 
-    # #load pquad dataset
-    # print("loading pquad dataset")
-    # pquad_dataset = load_pquad_dataset()
-    # print(f"pquad dataset loaded : {len(pquad_dataset)} samples")
-    # print("--------------------------------")
+    #load pquad dataset
+    print("loading pquad dataset")
+    pquad_dataset = load_pquad_dataset()
+    print(f"pquad dataset loaded : {len(pquad_dataset)} samples")
+    print("--------------------------------")
 
-    # # merge synthetic and pquad dataset
-    # print("start to merge synthetic and pquad dataset")
-    # all_dataset = synthetic_dataset + pquad_dataset
-    # print(f"successfully merged synthetic and pquad dataset")
-    # print("--------------------------------")
+    # merge synthetic and pquad dataset
+    print("start to merge synthetic and pquad dataset")
+    all_dataset = synthetic_dataset + pquad_dataset
+    print(f"successfully merged synthetic and pquad dataset")
+    print("--------------------------------")
 
-    # # removing false negative samples from all dataset
-    # print("start to remove false negative samples from all dataset")
-    # all_dataset = remove_false_negative(all_dataset, random_negative_sample=False)
-    # print(f"successfully removed false negative samples from all dataset")
-    # print("--------------------------------")
+    # removing false negative samples from all dataset
+    print("start to remove false negative samples from all dataset")
+    all_dataset = remove_false_negative(all_dataset, random_negative_sample=False)
+    print(f"successfully removed false negative samples from all dataset")
+    print("--------------------------------")
 
-    with open("/home/firouzi/embedding_model/data/train.json", "r", encoding="utf-8") as f:
-        all_dataset = json.load(f)
+    # with open("/home/firouzi/embedding_model/data/train.json", "r", encoding="utf-8") as f:
+    #     all_dataset = json.load(f)
 
-    for i in range(len(all_dataset)):
-        all_dataset[i]['passage_negative_random'] = []
+    # for i in range(len(all_dataset)):
+    #     all_dataset[i]['passage_negative_random'] = []
 
     #generate random negative samples
     print("start to generate random negative samples")
diff --git a/requirements.txt b/requirements.txt
index 12f88f9..25de51a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
 python-dotenv==1.1.1
-hazm
-faiss-cpu
\ No newline at end of file
+hazm==0.10.0
+faiss-cpu==1.12.0
+sentence-transformers==5.1.2
+einops==0.8.1
\ No newline at end of file
diff --git a/research_notebook/data_preprocess/bge.ipynb b/research_notebook/data_preprocess/bge.ipynb
new file mode 100644
index 0000000..5c98d62
--- /dev/null
+++ b/research_notebook/data_preprocess/bge.ipynb
@@ -0,0 +1,173 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9dbad513",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n",
+      "Downloading readme: 100%|██████████| 419/419 [00:00<00:00, 1.18MB/s]\n",
+      "Downloading data: 100%|██████████| 1.59M/1.59M [00:01<00:00, 1.03MB/s]\n",
+      "Generating train split: 100%|██████████| 7000/7000 [00:00<00:00, 175360.77 examples/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from datasets import load_dataset\n",
+    "\n",
+    "ds = load_dataset(\"virattt/financial-qa-10K\", split=\"train\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "7330f385",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n",
+       " 'pos': 'Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.',\n",
+       " 'id': '0'}"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ds = ds.select_columns(column_names=[\"question\", \"context\"])\n",
+    "ds = ds.rename_column(\"question\", \"query\")\n",
+    "ds = ds.rename_column(\"context\", \"pos\")\n",
+    "ds = ds.add_column(\"id\", [str(i) for i in range(len(ds))])\n",
+    "ds[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "5ba361dd",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map: 100%|██████████| 7000/7000 [00:00<00:00, 19176.72 examples/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "np.random.seed(520)\n",
+    "neg_num = 10\n",
+    "\n",
+    "def str_to_lst(data):\n",
+    "    data[\"pos\"] = [data[\"pos\"]]\n",
+    "    return data\n",
+    "\n",
+    "# sample negative texts\n",
+    "new_col = []\n",
+    "for i in range(len(ds)):\n",
+    "    ids = np.random.randint(0, len(ds), size=neg_num)\n",
+    "    while i in ids:\n",
+    "        ids = np.random.randint(0, len(ds), size=neg_num)\n",
+    "    neg = [ds[i.item()][\"pos\"] for i in ids]\n",
+    "    new_col.append(neg)\n",
+    "ds = ds.add_column(\"neg\", new_col)\n",
+    "\n",
+    "# change the key of 'pos' to a list\n",
+    "ds = ds.map(str_to_lst)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "bf3241ca",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "instruction = \"Represent this sentence for searching relevant passages: \"\n",
+    "ds = ds.add_column(\"prompt\", [instruction]*len(ds))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "a35c1466",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)\n",
+    "train = split[\"train\"]\n",
+    "test = split[\"test\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "24f3f7fb",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 26.22ba/s]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "16583481"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train.to_json(\"training.json\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c5cc42ed",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/research_notebook/data_preprocess/test.ipynb b/research_notebook/data_preprocess/test.ipynb index 17e0371..3aaa0df 100644 --- a/research_notebook/data_preprocess/test.ipynb +++ b/research_notebook/data_preprocess/test.ipynb @@ -141,6 +141,18 @@ "id": "53e5e322", "metadata": {}, "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset(\"virattt/financial-qa-10K\", split=\"train\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fabd9d8", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/train/jina/jina_train.py b/train/jina/jina_train.py index 4ab2179..cc0b1e6 100644 --- a/train/jina/jina_train.py +++ b/train/jina/jina_train.py @@ -1,4 +1,4 @@ -from datasets import Dataset +from datasets import Dataset, load_dataset, VerificationMode import json from sentence_transformers import ( SentenceTransformer, @@ -10,15 +10,18 @@ from sentence_transformers.training_args import BatchSamplers from sentence_transformers.evaluation import RerankingEvaluator +print("start") ########### Load model ########### +print("loading model") # 1. Load a model to finetune with 2. (Optional) model card data model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True, local_files_only=False, - model_kwargs={'default_task': 'retrieval'}) + model_kwargs={'default_task': 'retrieval.passage'}) ########### Load dataset ########### +print("loading dataset") # 3. Load a dataset to finetune on with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f: all_dataset = json.load(f) @@ -32,7 +35,7 @@ negatives_4 = [] negatives_5 = [] for data in all_dataset: anchors.append(data["question"]) - positives.append(data["passage_positive"]) + positives.append(data["passage_positive"][0]) all_negatives = data["passage_negative"] + data["passage_negative_random"] if len(all_negatives) < 5: for i in range(5 - len(all_negatives)): @@ -57,35 +60,44 @@ dataset_split = dataset.train_test_split(test_size=0.05, seed=42) train_dataset = dataset_split["train"] eval_dataset = dataset_split["test"] +# print(train_dataset[1]) + +# dataset = load_dataset("persiannlp/parsinlu_reading_comprehension", verification_mode=VerificationMode.NO_CHECKS) +# train_dataset = dataset["train"] +# print(train_dataset[1]) ########### Load loss function ########### +print("loading loss function") # 4. Define a loss function loss = MultipleNegativesRankingLoss(model) +# loss = ########### Load training arguments ########### +print("loading training arguments") # 5. 
 args = SentenceTransformerTrainingArguments(
     # Required parameter:
     output_dir="models/jina_v3",
     # Optional training parameters:
     num_train_epochs=1,
-    per_device_train_batch_size=16,
-    per_device_eval_batch_size=16,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
     learning_rate=2e-5,
     warmup_ratio=0.1,
     fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
     bf16=False,  # Set to True if you have a GPU that supports BF16
-    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+    batch_sampler=BatchSamplers.BATCH_SAMPLER,  # note: MultipleNegativesRankingLoss usually benefits from BatchSamplers.NO_DUPLICATES
     # Optional tracking/debugging parameters:
-    eval_strategy="steps",
-    eval_steps=100,
+    # eval_strategy="steps",
+    # eval_steps=5,
     save_strategy="steps",
-    save_steps=100,
+    save_steps=10,
     save_total_limit=2,
-    logging_steps=100,
+    logging_steps=5,
     run_name="jina_v3",  # Will be used in W&B if `wandb` is installed
 )
 
 ########### Load evaluator ###########
+print("loading evaluator")
 # 6. (Optional) Create an evaluator & evaluate the base model
 eval_dataset_evaluator = [
     {
@@ -99,17 +111,18 @@ dev_evaluator = RerankingEvaluator(
     name="jina_v3",
     samples=eval_dataset_evaluator,
 )
-dev_evaluator(model)
+# dev_evaluator(model)
 
 ########### Load trainer ###########
+print("loading trainer")
 # 7. Create a trainer & train
 trainer = SentenceTransformerTrainer(
     model=model,
     args=args,
     train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
+    # eval_dataset=eval_dataset,
     loss=loss,
-    evaluator=dev_evaluator,
+    # evaluator=dev_evaluator,
 )
 trainer.train()
@@ -124,5 +137,8 @@ trainer.train()
 # test_evaluator(model)
 
 ########### Save the trained model ###########
+print("saving model")
 # 8. Save the trained model
 model.save_pretrained("models/jina_v3")
+print("model saved")
+print("end")
diff --git a/train/jina/test.py b/train/jina/test.py
new file mode 100644
index 0000000..6a577bc
--- /dev/null
+++ b/train/jina/test.py
@@ -0,0 +1,101 @@
+from sentence_transformers import (
+    SentenceTransformer,
+    InputExample,
+    SentenceTransformerTrainingArguments,
+    SentenceTransformerTrainer,
+)
+from sentence_transformers.losses import TripletLoss, MatryoshkaLoss, TripletDistanceMetric
+from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction, SequentialEvaluator
+from transformers import EarlyStoppingCallback
+import torch
+import os
+import json
+from datasets import Dataset
+
+model = SentenceTransformer("jinaai/jina-embeddings-v3",
+                            trust_remote_code=True,
+                            local_files_only=False,
+                            model_kwargs={'default_task': 'text-matching'})
+
+print("loading dataset")
+# 3. Load a dataset to finetune on
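+# Each train_100.json record is assumed to look like this (schema inferred
+# from the fields used below; values are illustrative):
+# {
+#     "question": "...",
+#     "passage_positive": ["positive passage"],
+#     "passage_negative": ["hard negative", "..."],
+#     "passage_negative_random": ["random negative", "..."]
+# }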
+with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
+    all_dataset = json.load(f)
+
+anchors = []
+positives = []
+negatives_1 = []
+negatives_2 = []
+negatives_3 = []
+negatives_4 = []
+negatives_5 = []
+for data in all_dataset:
+    anchors.append(data["question"])
+    positives.append(data["passage_positive"][0])  # passage_positive is a list; take the first passage (as in jina_train.py)
+    all_negatives = data["passage_negative"] + data["passage_negative_random"]
+    if len(all_negatives) < 5:
+        for i in range(5 - len(all_negatives)):
+            all_negatives.append(all_negatives[0])
+    negatives_1.append(all_negatives[0])
+    negatives_2.append(all_negatives[1])
+    negatives_3.append(all_negatives[2])
+    negatives_4.append(all_negatives[3])
+    negatives_5.append(all_negatives[4])
+
+dataset = Dataset.from_dict({
+    "anchor": anchors,
+    "positive": positives,
+    "negative": negatives_1,
+})
+
+dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
+
+train_dataset = dataset_split["train"]
+eval_dataset = dataset_split["test"]
+
+
+loss = TripletLoss(model,
+                   distance_metric=TripletDistanceMetric.COSINE,
+                   triplet_margin=0.75)
+
+dev_evaluator = TripletEvaluator(
+    anchors=eval_dataset["anchor"],
+    positives=eval_dataset["positive"],
+    negatives=eval_dataset["negative"],
+    main_similarity_function=SimilarityFunction.COSINE
+)
+
+training_args = SentenceTransformerTrainingArguments(
+    output_dir="save_dir",
+    num_train_epochs=1,
+    per_device_train_batch_size=2,
+    per_device_eval_batch_size=2,
+    learning_rate=2.5e-5,
+    warmup_ratio=0.1,
+    greater_is_better=True,
+    load_best_model_at_end=True,
+    metric_for_best_model="eval_cosine_accuracy",
+    fp16=False,
+    bf16=True,
+    eval_strategy="steps",
+    eval_steps=50,
+    save_strategy="steps",
+    save_steps=50,
+    save_total_limit=10,
+    logging_steps=50,
+    logging_first_step=True,
+)
+
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    loss=loss,
+    evaluator=dev_evaluator,
+)
+
+# pretraining_encoding = model.encode(["The human torch was denied a bank loan."])
+# print("Pre-training encoding:", pretraining_encoding)
+
+# Begin fine-tuning
+trainer.train()
diff --git a/train/jina/test_2.py b/train/jina/test_2.py
new file mode 100644
index 0000000..220551c
--- /dev/null
+++ b/train/jina/test_2.py
@@ -0,0 +1,109 @@
+from datasets import Dataset
+import json
+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+from sentence_transformers.losses import MultipleNegativesRankingLoss
+from sentence_transformers.training_args import BatchSamplers
+from sentence_transformers.evaluation import RerankingEvaluator
+
+
+print("start")
+########### Load model ###########
+print("loading model")
+# 1. Load a model to finetune with 2. (Optional) model card data
+model = SentenceTransformer("jinaai/jina-embeddings-v3",
+                            trust_remote_code=True,
+                            local_files_only=False,
+                            model_kwargs={'default_task': 'retrieval.passage'})
+
+
+
+########### Load dataset ###########
+print("loading dataset")
+# 3. Load a dataset to finetune on
+with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
+    all_dataset = json.load(f)
+
+# MultipleNegativesRankingLoss expects InputExample(texts=[anchor, positive]);
+# the explicit negatives in the data are ignored here, and in-batch negatives are used instead.
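+# Illustrative sketch of in-batch negatives (not executed): given a batch of
+#   InputExample(texts=["q1", "p1"]) and InputExample(texts=["q2", "p2"]),
+# the loss trains q1 to rank p1 above p2 and q2 to rank p2 above p1, so every
+# other positive in the batch acts as a negative for free.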
+from sentence_transformers.data import InputExample
+from sklearn.model_selection import train_test_split
+
+all_examples = []
+for data in all_dataset:
+    all_examples.append(InputExample(texts=[data["question"], data["passage_positive"][0]]))  # first positive passage (passage_positive is a list)
+
+# Split the dataset into train and evaluation
+train_examples, eval_examples = train_test_split(all_examples, test_size=0.05, random_state=42)
+
+print(f"Training with {len(train_examples)} examples")
+print(f"Evaluating with {len(eval_examples)} examples")
+
+########### Load loss function ###########
+print("loading loss function")
+# 4. Define a loss function
+loss = MultipleNegativesRankingLoss(model)
+
+########### Load evaluator ###########
+print("loading evaluator")
+# 6. (Optional) Create an evaluator
+# The evaluator samples are rebuilt from 'eval_examples' so evaluation only covers held-out queries
+eval_dataset_evaluator = []
+for data in all_dataset:  # scan all_dataset to recover the matching negatives
+    example_query = data["question"]
+    example_positive = data["passage_positive"][0]
+
+    # Find if this example is in our eval set
+    is_in_eval = False
+    for eval_ex in eval_examples:
+        if eval_ex.texts[0] == example_query and eval_ex.texts[1] == example_positive:
+            is_in_eval = True
+            break
+
+    if is_in_eval:
+        all_negatives = data["passage_negative"] + data["passage_negative_random"]
+        if len(all_negatives) < 5:
+            for i in range(5 - len(all_negatives)):
+                all_negatives.append(all_negatives[0])  # Pad negatives
+
+        eval_dataset_evaluator.append({
+            "query": example_query,
+            "positive": [example_positive],
+            "negative": all_negatives[:5],  # Use the original negatives for evaluation
+        })
+
+dev_evaluator = RerankingEvaluator(
+    name="jina_v3",
+    samples=eval_dataset_evaluator,
+)
+# dev_evaluator(model)  # You can still run this to check base performance
+
+########### Train the model ###########
+print("starting training with model.fit()")
+from torch.utils.data import DataLoader
+
+# Create a DataLoader for the training examples
+train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=4)
+
+# 7. Train the model using model.fit()
+model.fit(
+    train_objectives=[(train_dataloader, loss)],
+    evaluator=dev_evaluator,
+    epochs=1,
+    evaluation_steps=5,
+    warmup_steps=int(len(train_dataloader) * 0.1),  # 10% warmup
+    output_path="models/jina_v3",
+    save_best_model=True,
+    show_progress_bar=True,
+    use_amp=True,  # Replaces fp16=True
+)
+
+########### Save the trained model ###########
+# model.fit() already saves the best model to output_path, but you can save again
+print("saving final model")
+model.save_pretrained("models/jina_v3_final")
+print("model saved")
+print("end")
\ No newline at end of file
diff --git a/train/jina/test_3.py b/train/jina/test_3.py
new file mode 100644
index 0000000..b798a6a
--- /dev/null
+++ b/train/jina/test_3.py
@@ -0,0 +1,141 @@
+from datasets import Dataset
+import json
+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+from sentence_transformers.losses import MultipleNegativesRankingLoss
+from sentence_transformers.training_args import BatchSamplers
+from sentence_transformers.evaluation import RerankingEvaluator
+
+
+print("start")
+########### Load model ###########
+print("loading model")
+# 1. Load a model to finetune with 2. (Optional) model card data
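+# Note: jina-embeddings-v3 routes encoding through task-specific LoRA adapters;
+# 'default_task' selects which adapter encode() uses ('retrieval.passage' for
+# documents, with 'retrieval.query' as the query-side counterpart).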
+model = SentenceTransformer("jinaai/jina-embeddings-v3",
+                            trust_remote_code=True,
+                            local_files_only=False,
+                            model_kwargs={'default_task': 'retrieval.passage'})
+
+
+########### Load dataset ###########
+print("loading dataset")
+# 3. Load a dataset to finetune on
+with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
+    all_dataset = json.load(f)
+
+anchors = []
+positives = []
+negatives_1 = []
+negatives_2 = []
+negatives_3 = []
+negatives_4 = []
+negatives_5 = []
+for data in all_dataset:
+    anchors.append(data["question"])
+    positives.append(data["passage_positive"][0])  # passage_positive is a list; take the first passage (as in jina_train.py)
+    all_negatives = data["passage_negative"] + data["passage_negative_random"]
+    if len(all_negatives) < 5:
+        for i in range(5 - len(all_negatives)):
+            all_negatives.append(all_negatives[0])
+    negatives_1.append(all_negatives[0])
+    negatives_2.append(all_negatives[1])
+    negatives_3.append(all_negatives[2])
+    negatives_4.append(all_negatives[3])
+    negatives_5.append(all_negatives[4])
+
+dataset = Dataset.from_dict({
+    "anchor": anchors,
+    "positive": positives,
+    "negative_1": negatives_1,
+    "negative_2": negatives_2,
+    "negative_3": negatives_3,
+    "negative_4": negatives_4,
+    "negative_5": negatives_5,
+})
+
+dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
+
+train_dataset = dataset_split["train"]
+eval_dataset = dataset_split["test"]
+
+########### Load loss function ###########
+print("loading loss function")
+# 4. Define a loss function
+loss = MultipleNegativesRankingLoss(model)
+
+########### Load training arguments ###########
+print("loading training arguments")
+# 5. (Optional) Specify training arguments
+# NOTE: these args are only consumed by the Trainer API; model.fit() below does not use them.
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir="models/jina_v3",
+    # Optional training parameters:
+    num_train_epochs=1,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    learning_rate=2e-5,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=5,
+    save_strategy="steps",
+    save_steps=10,
+    save_total_limit=2,
+    logging_steps=5,
+    run_name="jina_v3",  # Will be used in W&B if `wandb` is installed
+)
+
+########### Load evaluator ###########
+print("loading evaluator")
+# 6. (Optional) Create an evaluator & evaluate the base model
+eval_dataset_evaluator = [
+    {
+        "query": sample["anchor"],
+        "positive": [sample["positive"]],
+        "negative": [sample["negative_1"], sample["negative_2"], sample["negative_3"], sample["negative_4"], sample["negative_5"]],
+    }
+    for sample in eval_dataset
+]
+dev_evaluator = RerankingEvaluator(
+    name="jina_v3",
+    samples=eval_dataset_evaluator,
+)
+# dev_evaluator(model)
+
+########### Load trainer ###########
+print("loading trainer")
+# 7. Train the model using model.fit()
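+# model.fit() expects DataLoaders of InputExample objects rather than a
+# datasets.Dataset, so convert the split first (anchor, positive, then the
+# five hard negatives).
+from torch.utils.data import DataLoader
+from sentence_transformers import InputExample
+
+train_examples = [
+    InputExample(texts=[s["anchor"], s["positive"], s["negative_1"], s["negative_2"],
+                        s["negative_3"], s["negative_4"], s["negative_5"]])
+    for s in train_dataset
+]
+train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=4)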
+model.fit(
+    train_objectives=[(train_dataloader, loss)],
+    evaluator=dev_evaluator,
+    epochs=1,
+    evaluation_steps=5,
+    warmup_steps=int(len(train_dataloader) * 0.1),  # 10% warmup
+    output_path="models/jina_v3",
+    save_best_model=True,
+    show_progress_bar=True,
+    use_amp=True,  # Replaces fp16=True
+)
+
+########### Load test evaluator ###########
+# (Optional) Evaluate the trained model on the test set
+# test_evaluator = TripletEvaluator(
+#     anchors=test_dataset["anchor"],
+#     positives=test_dataset["positive"],
+#     negatives=test_dataset["negative"],
+#     name="all-nli-test",
+# )
+# test_evaluator(model)
+
+########### Save the trained model ###########
+print("saving model")
+# 8. Save the trained model
+model.save_pretrained("models/jina_v3")
+print("model saved")
+print("end")
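+
+# (Optional) quick smoke test of the saved model; a minimal sketch assuming the
+# output path above and the loading pattern used at the top of this file:
+# tuned = SentenceTransformer("models/jina_v3",
+#                             trust_remote_code=True,
+#                             model_kwargs={'default_task': 'retrieval.query'})
+# print(tuned.encode(["a sample query"]).shape)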