# embedding_model/evaluation/evaluate_with_religous_50000.py

import json
import os

import faiss
import numpy as np
import requests
import tqdm
from mteb.encoder_interface import PromptType


class CustomModel:
    """Minimal mteb-style encoder that proxies embedding requests to a local HTTP server."""

    def __init__(self):
        self.session = requests.Session()

    def get_embedding(self, sentences, prompt_name):
        """POST one batch of sentences to the local /embed endpoint and return their vectors."""
        embedding_url = "http://127.0.0.1:5010/embed"
        headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        data = {
            "inputs": sentences,
            "normalize": True,
            "prompt_name": prompt_name,
            "truncate": False,
            "truncation_direction": "Right",
        }
        response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)
        response.raise_for_status()  # surface HTTP errors instead of trying to parse an error body
        return response.json()

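    # Request/response shape this client assumes (it matches the Hugging Face
    # text-embeddings-inference /embed route; adjust the fields if your server
    # differs):
    #   request:  {"inputs": ["text a", "text b"], "normalize": true,
    #              "prompt_name": "query", "truncate": false,
    #              "truncation_direction": "Right"}
    #   response: [[0.01, -0.02, ...], [0.05, 0.11, ...]]  # one vector per input
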
    def get_simplexity_query2vec_results(self, sentences, template):
        """Embed sentences in fixed-size batches; show a progress bar for large inputs."""
        # tqdm.trange adds a progress bar; plain range avoids the overhead for small inputs.
        my_range = range if len(sentences) < 2000 else tqdm.trange
        batch_size = 64
        vec = []
        for i in my_range(0, len(sentences), batch_size):
            stop_idx = min(i + batch_size, len(sentences))
            vec += self.get_embedding(sentences[i:stop_idx], template)
        return vec

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        # Map mteb's prompt_type onto the server-side prompt template name;
        # documents are the default when no prompt type is given.
        if prompt_type is None or prompt_type == PromptType.document:
            template = "document"
        elif prompt_type == PromptType.query:
            template = "query"
        else:
            raise ValueError(f"Unexpected prompt_type: {prompt_type}")
        all_embeddings = self.get_simplexity_query2vec_results(sentences, template)
        # faiss operates on float32 arrays, so fix the dtype here.
        return np.asarray(all_embeddings, dtype=np.float32)


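# Illustrative usage of CustomModel (a sketch, assuming the local embedding
# server from get_embedding is running; the sentence below is made up):
#
#   model = CustomModel()
#   vecs = model.encode(["what is faith?"], task_name="", prompt_type=PromptType.query)
#   print(vecs.shape)  # (1, embedding_dim)

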
def build_faiss_index(embeddings: np.ndarray) -> faiss.Index:
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)  # inner product equals cosine similarity after normalization
    index.add(embeddings)
    return index


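# The server is asked to normalize embeddings (normalize=True above), which is
# what makes inner product equal to cosine similarity. If that flag were off,
# the vectors would need explicit L2 normalization before indexing, e.g.:
#
#   embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
#   faiss.normalize_L2(embeddings)  # in-place row-wise L2 normalization

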
def recall_at_k(index, query_embeddings, ground_truth, all_docs, k=10):
    """
    ground_truth: the correct (positive) passage text for each query; a query
    counts as a hit when that passage appears among its top-k retrieved docs.
    """
    scores, retrieved = index.search(query_embeddings, k)
    hits = 0
    for i, gt in enumerate(ground_truth):
        # Map retrieved row indices back to passage texts and check for the positive.
        if gt in [all_docs[j] for j in retrieved[i]]:
            hits += 1
    return hits / len(ground_truth)


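# Toy sanity check for recall_at_k (a sketch over a made-up two-document corpus):
#
#   docs = ["doc a", "doc b"]
#   doc_vecs = model.encode(docs, "", prompt_type=PromptType.document)
#   idx = build_faiss_index(doc_vecs)
#   q_vecs = model.encode(["a query about a"], "", prompt_type=PromptType.query)
#   recall_at_k(idx, q_vecs, ground_truth=["doc a"], all_docs=docs, k=1)

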
def evaluate():
    # model_name = "KaLM-embedding-multilingual-mini-instruct-v2.5"
    # model_name = "KaLM-Embedding-Gemma3-12B-2511"
    # model_name = "llama-embed-nemotron-8b"
    # model_name = "embeddinggemma-300m"
    model = CustomModel()
    file_path = os.path.dirname(__file__)
    with open(os.path.join(file_path, "../data/dataset/religous_test_50000/test_religous.json"), "r") as f:
        data = json.load(f)
    # Flatten every positive passage into one corpus; there is one query per item.
    all_docs = [passage for item in data for passage in item["passage_positive"]]
    all_queries = [item["question"] for item in data]
    num_docs = len(all_docs)
    num_test = len(all_queries)  # one test query per dataset item
    doc_embedding = model.encode(all_docs[0:num_docs], None, prompt_type=PromptType.document)
    index = build_faiss_index(doc_embedding)
    query_embeddings = model.encode(all_queries[0:num_test], None, prompt_type=PromptType.query)
    # The first positive passage of each item serves as that query's ground truth.
    ground_truth = [data[i]["passage_positive"][0] for i in range(num_test)]
    for k in [1, 5, 10]:
        r = recall_at_k(index, query_embeddings, ground_truth, all_docs, k)
        print(f"Recall@{k}: {r:.6f}")
    # finetune:
    # Recall@1: 0.401740
    # Recall@5: 0.614240
    # Recall@10: 0.683000


def main():
    # get_results()
    evaluate()


if __name__ == "__main__":
    main()