# embedding_model/train/gemma/gemma_train.py
import argparse
import json

from datasets import Dataset
from peft import LoraConfig, TaskType
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback


def get_ndcg(model, dataset):
    # Placeholder: does not compute NDCG yet; it only sanity-checks that the
    # model can embed a query by printing the first 20 dimensions.
    query_embeddings = model.encode_query("hey")
    print(query_embeddings[:20])
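

# A minimal NDCG@k sketch (hypothetical; not part of the original script).
# It ranks each triplet's positive against that triplet's single negative, so
# it is a crude sanity metric rather than a full retrieval evaluation. With
# one relevant document per query the ideal DCG is 1, so NDCG reduces to
# 1 / log2(rank + 1). Caveat: if --add_prompt was used, the stored strings
# already contain prompts, and encode_query/encode_document prepend the
# model's default prompts on top.
def get_ndcg_sketch(model, dataset, k=10, num_queries=100):
    import math

    scores = []
    for row in dataset.select(range(min(num_queries, len(dataset)))):
        # Candidate list: the positive first, then this triplet's negative.
        candidates = [row["positive"], row["negative"]]
        query_emb = model.encode_query(row["anchor"])
        doc_embs = model.encode_document(candidates)
        sims = model.similarity(query_emb, doc_embs)[0]
        # Rank of the positive among the candidates (1 = best).
        rank = int((sims > sims[0]).sum().item()) + 1
        scores.append(1.0 / math.log2(rank + 1) if rank <= k else 0.0)
    return sum(scores) / len(scores)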


def main(add_prompt, lora):
    ########### Load dataset ###########
    print("loading dataset")
    with open("/home/firouzi/embedding_model/data/dataset_train.json", "r", encoding="utf-8") as f:
        all_dataset = json.load(f)
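
    # Prompt prefixes matching EmbeddingGemma's retrieval prompt templates
    # (query side vs. document side); applied only when --add_prompt is set.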
    query_prompt = "task: search result | query: "
    document_prompt = "title: none | text: "
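
    # Flatten the raw records into (anchor, positive, negative) triplets:
    # one training example per hard or random negative passage.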
    data_as_dicts = []
    for data in all_dataset:
        for data_neg in (data["passage_negative"] + data["passage_negative_random"]):
            if add_prompt:
                data_as_dicts.append({
                    "anchor": query_prompt + data["question"],
                    "positive": document_prompt + data["passage_positive"][0],
                    "negative": document_prompt + data_neg,
                })
            else:
                data_as_dicts.append({
                    "anchor": data["question"],
                    "positive": data["passage_positive"][0],
                    "negative": data_neg,
                })
    train_dataset = Dataset.from_list(data_as_dicts)
    print(f"len train_dataset: {len(train_dataset)}")
    ####################################

    print("loading model")
    model = SentenceTransformer("google/embeddinggemma-300M").to(device="cuda:0")
    if lora:
        # Create a LoRA adapter for the model
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=64,
            lora_alpha=128,
            lora_dropout=0.1,
        )
        model.add_adapter(peft_config)
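
    # MultipleNegativesRankingLoss also treats the other in-batch passages as
    # negatives, on top of the explicit hard negative in each triplet.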
    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir="./models/gemma",
        num_train_epochs=1,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        warmup_ratio=0.05,
        # num_rows exceeds the total number of optimizer steps, so no
        # periodic log events fire during this single-epoch run.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_steps=10000,
        save_total_limit=2,
    )

    class MyCallback(TrainerCallback):
        """Callback that runs the evaluation function on every log event."""

        def __init__(self, evaluate):
            self.evaluate = evaluate  # evaluation function

        def on_log(self, args, state, control, **kwargs):
            # Run the embedding evaluation whenever the trainer logs
            print(f"Step {state.global_step} finished. Running evaluation:")
            self.evaluate()

    def evaluate():
        get_ndcg(model, train_dataset)
print("start to training model...")
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
loss=loss,
# callbacks=[MyCallback(evaluate)]
)
trainer.train()
print("training done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--add_prompt", action="store_true")
    parser.add_argument("--lora", action="store_true")
    args = parser.parse_args()
    print(args.lora)
    main(args.add_prompt, args.lora)
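
# Example invocation (both flags default to False):
#   python gemma_train.py --add_prompt --lora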