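"""Fine-tune google/embeddinggemma-300M as a retrieval embedding model.

Builds (anchor, positive, negative) triplets from a local JSON training file,
optionally prepends the EmbeddingGemma query/document prompt prefixes
(--add_prompt), and trains with MultipleNegativesRankingLoss, either as full
fine-tuning or through a LoRA adapter (--lora).
"""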
import argparse
import json

from datasets import Dataset
from peft import LoraConfig, TaskType
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback


def get_ndcg(model, dataset):
    # Placeholder evaluation: encodes a fixed dummy query and prints the first
    # 20 embedding dimensions as a sanity check. No NDCG is computed yet.
    query_embeddings = model.encode_query("hey")
    print(query_embeddings[:20])


def main(add_prompt, lora):
    ########### Load dataset ###########
    print("loading dataset")
    with open("/home/firouzi/embedding_model/data/dataset_train.json", "r", encoding="utf-8") as f:
        all_dataset = json.load(f)

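    # Expected shape of each record in dataset_train.json (inferred from the
    # fields accessed below):
    #   {"question": str,
    #    "passage_positive": [str, ...],
    #    "passage_negative": [str, ...],
    #    "passage_negative_random": [str, ...]}
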
    # Prompt prefixes in the EmbeddingGemma query/document format.
    query_prompt = "task: search result | query: "
    document_prompt = "title: none | text: "

    # Build one (anchor, positive, negative) triplet per negative passage.
    data_as_dicts = []
    for data in all_dataset:
        for data_neg in (data["passage_negative"] + data["passage_negative_random"]):
            if add_prompt:
                data_as_dicts.append({
                    "anchor": query_prompt + data["question"],
                    "positive": document_prompt + data["passage_positive"][0],
                    "negative": document_prompt + data_neg,
                })
            else:
                data_as_dicts.append({
                    "anchor": data["question"],
                    "positive": data["passage_positive"][0],
                    "negative": data_neg,
                })

    train_dataset = Dataset.from_list(data_as_dicts)
    print(f"len train_dataset: {len(train_dataset)}")

    ####################################
    print("loading model")
    model = SentenceTransformer("google/embeddinggemma-300M").to(device="cuda:0")

    if lora:
        # Create a LoRA adapter so only the low-rank adapter weights are trained.
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=64,
            lora_alpha=128,
            lora_dropout=0.1,
        )
        model.add_adapter(peft_config)

    # In-batch negatives loss; the explicit "negative" column adds one hard
    # negative per anchor on top of the other in-batch candidates.
    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir="./models/gemma",
        num_train_epochs=1,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        warmup_ratio=0.05,
        # Note: logging_steps counts optimizer steps, not samples, so using the
        # row count means logging (and any logging-based callback) effectively
        # never fires within a single epoch at this batch size.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_steps=10000,
        save_total_limit=2,
    )

    class MyCallback(TrainerCallback):
        """A callback that runs evaluation every time the trainer logs."""

        def __init__(self, evaluate):
            self.evaluate = evaluate  # evaluation function to call

        def on_log(self, args, state, control, **kwargs):
            # Evaluate the embedding model whenever a log event fires.
            print(f"Step {state.global_step} finished. Running evaluation:")
            self.evaluate()

    def evaluate():
        get_ndcg(model, train_dataset)

print("start to training model...")
|
|
trainer = SentenceTransformerTrainer(
|
|
model=model,
|
|
args=args,
|
|
train_dataset=train_dataset,
|
|
loss=loss,
|
|
# callbacks=[MyCallback(evaluate)]
|
|
)
|
|
trainer.train()
|
|
print("training done")
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--add_prompt", action="store_true",
                        help="prepend the query/document prompt prefixes to each example")
    parser.add_argument("--lora", action="store_true",
                        help="train a LoRA adapter instead of full fine-tuning")
    args = parser.parse_args()
    print(f"add_prompt={args.add_prompt}, lora={args.lora}")
    main(args.add_prompt, args.lora)