"""Fine-tune jinaai/jina-embeddings-v3 with MultipleNegativesRankingLoss.

Reads a JSON file of samples shaped like
``{"question": str, "passage_positive": [str, ...],
   "passage_negative": [str, ...], "passage_negative_random": [str, ...]}``,
builds a HF ``Dataset`` with a fixed number of negative columns per anchor,
trains with sentence-transformers, and saves the result to ``models/jina_v3``.
"""
from datasets import Dataset, load_dataset, VerificationMode
import json
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import RerankingEvaluator

# Fixed number of negative columns ("negative_1".."negative_5") in the dataset.
NUM_NEGATIVES = 5

print("start")

########### Load model ###########
print("loading model")
# 1. Load a model to finetune
model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    local_files_only=False,
    model_kwargs={'default_task': 'retrieval.passage'},
)

########### Load dataset ###########
print("loading dataset")
# 2. Load the raw training samples from disk.
with open("/home/firouzi/embedding_model/data/train_100.json", "r", encoding="utf-8") as f:
    all_dataset = json.load(f)

anchors = []
positives = []
# One column per negative slot: negative_columns[i] becomes "negative_{i+1}".
negative_columns = [[] for _ in range(NUM_NEGATIVES)]

for data in all_dataset:
    all_negatives = data["passage_negative"] + data["passage_negative_random"]
    if not all_negatives:
        # BUG FIX: the original padded with all_negatives[0], which raises an
        # opaque IndexError when a sample has no negatives at all. Fail with
        # a clear message identifying the offending sample instead.
        raise ValueError(f"sample has no negative passages: {data['question']!r}")
    # Pad by repeating the first negative until every sample has exactly
    # NUM_NEGATIVES negatives (the Dataset needs rectangular columns).
    while len(all_negatives) < NUM_NEGATIVES:
        all_negatives.append(all_negatives[0])
    anchors.append(data["question"])
    # Only the first positive passage is used for training.
    positives.append(data["passage_positive"][0])
    for column, negative in zip(negative_columns, all_negatives):
        column.append(negative)

dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
    # Produces the same keys as before: "negative_1" .. "negative_5".
    **{f"negative_{i + 1}": column for i, column in enumerate(negative_columns)},
})
# Hold out 5% for evaluation; fixed seed keeps the split reproducible.
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

########### Load loss function ###########
print("loading loss function")
# 3. In-batch negatives loss; the explicit negative_* columns are extra hard
# negatives on top of the other in-batch passages.
loss = MultipleNegativesRankingLoss(model)

########### Load training arguments ###########
print("loading training arguments")
# 4. Training hyperparameters (unchanged from the original run).
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/jina_v3",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,   # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch.
    batch_sampler=BatchSamplers.BATCH_SAMPLER,
    # Optional tracking/debugging parameters:
    save_strategy="steps",
    save_steps=10,
    save_total_limit=2,
    logging_steps=5,
    run_name="jina_v3",  # Will be used in W&B if `wandb` is installed
)

########### Load evaluator ###########
print("loading evaluator")
# 5. Build a RerankingEvaluator over the held-out split. NOTE(review): it is
# constructed but not passed to the trainer (matching the original behavior);
# call dev_evaluator(model) manually to evaluate.
eval_dataset_evaluator = [
    {
        "query": sample["anchor"],
        "positive": [sample["positive"]],
        "negative": [sample[f"negative_{i + 1}"] for i in range(NUM_NEGATIVES)],
    }
    for sample in eval_dataset
]
dev_evaluator = RerankingEvaluator(
    name="jina_v3",
    samples=eval_dataset_evaluator,
)

########### Load trainer ###########
print("loading trainer")
# 6. Create a trainer & train.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

########### Save the trained model ###########
print("saving model")
# 7. Save the trained model.
model.save_pretrained("models/jina_v3")
print("model saved")

print("end")