import argparse
import json

from datasets import Dataset
from peft import LoraConfig, TaskType
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback

# Default location of the training data; override via the `dataset_path`
# parameter of main().
DEFAULT_DATASET_PATH = "/home/firouzi/embedding_model/data/dataset_train.json"


def get_ndcg(model, dataset):
    """Smoke-test evaluation hook.

    NOTE(review): despite its name this does not compute NDCG — it only
    encodes a fixed query and prints the first 20 embedding values. It
    should be replaced with a real NDCG computation over `dataset`.
    """
    query_embeddings = model.encode_query("hey")
    print(query_embeddings[:20])


def main(add_prompt, lora, dataset_path=DEFAULT_DATASET_PATH):
    """Fine-tune google/embeddinggemma-300M with a triplet ranking loss.

    Args:
        add_prompt: if True, prepend the EmbeddingGemma-style task prompts
            to queries and documents before training.
        lora: if True, attach and train a LoRA adapter instead of the full
            model weights.
        dataset_path: path to a JSON file with the training records. Each
            record is expected to provide "question", "passage_positive"
            (non-empty list), "passage_negative" and
            "passage_negative_random" (lists of strings).
    """
    ########### Load dataset ###########
    print("loading dataset")
    with open(dataset_path, "r", encoding="utf-8") as f:
        all_dataset = json.load(f)

    # Prompt prefixes used when add_prompt is set.
    query_prompt = "task: search result | query: "
    document_prompt = "title: none | text: "

    # Flatten every (question, positive, negative) combination into one
    # triplet row; both mined and random negatives contribute rows.
    data_as_dicts = []
    for data in all_dataset:
        for data_neg in data["passage_negative"] + data["passage_negative_random"]:
            if add_prompt:
                data_as_dicts.append(
                    {
                        "anchor": query_prompt + data["question"],
                        "positive": document_prompt + data["passage_positive"][0],
                        "negative": document_prompt + data_neg,
                    }
                )
            else:
                data_as_dicts.append(
                    {
                        "anchor": data["question"],
                        "positive": data["passage_positive"][0],
                        "negative": data_neg,
                    }
                )

    train_dataset = Dataset.from_list(data_as_dicts)
    print(f"len train_dataset: {len(train_dataset)}")
    ####################################

    print("loading model")
    model = SentenceTransformer("google/embeddinggemma-300M").to(device="cuda:0")

    if lora:
        # Create a LoRA adapter for the model
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=64,
            lora_alpha=128,
            lora_dropout=0.1,
        )
        model.add_adapter(peft_config)

    # In-batch negatives ranking loss over (anchor, positive, negative) rows.
    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir="./models/gemma",
        num_train_epochs=1,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        warmup_ratio=0.05,
        # logging_steps equals the number of dataset rows, not the number of
        # optimizer steps; presumably meant to log once per epoch — TODO confirm.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_steps=10000,
        save_total_limit=2,
    )

    class MyCallback(TrainerCallback):
        """A callback that evaluates the model at the end of epoch."""

        def __init__(self, evaluate):
            self.evaluate = evaluate  # evaluate function

        def on_log(self, args, state, control, **kwargs):
            # NOTE(review): on_log fires on every logging event, not strictly
            # at epoch end; it only coincides with epoch end because of the
            # logging_steps value chosen above.
            print(f"Step {state.global_step} finished. Running evaluation:")
            self.evaluate()

    def evaluate():
        # Closure over the model and dataset built above.
        get_ndcg(model, train_dataset)

    print("start to training model...")
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        # callbacks=[MyCallback(evaluate)]
    )
    trainer.train()
    print("training done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--add_prompt", action="store_true")
    parser.add_argument("--lora", action="store_true")
    args = parser.parse_args()
    print(args.lora)
    main(args.add_prompt, args.lora)