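"""Recall@k evaluation for embedding models served over a local HTTP endpoint.

Documents and queries are embedded through the service, indexed with FAISS,
and scored with Recall@{1,5,10} against the gold positive passages.
"""
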
import json
import os

import faiss
import numpy as np
import requests
import tqdm
from mteb.encoder_interface import PromptType


class CustomModel:
    """Thin client for a local embedding service, with an mteb-style encode()."""

    def __init__(self, model):
        self.session = requests.Session()
        self.model = model

    def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
        """POST sentences to the embedding endpoint in batches and collect vectors."""
        params = {"model": model, "template": template}
        headers = {"accept": "application/json"}

        # Only show a progress bar when the input is large enough to be slow.
        my_range = range if len(sentences) < 2000 else tqdm.trange

        batch_size = 1024
        vec = []
        for i in my_range(0, len(sentences), batch_size):
            # Slicing clamps at the end of the list, so no min() is needed.
            data = {"queries": sentences[i : i + batch_size]}
            response = self.session.post(
                embedding_url, headers=headers, params=params, json=data, timeout=600
            )
            response.raise_for_status()
            vec += response.json()
        return vec

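    # The endpoint's contract, inferred from the calls above (adjust to your
    # service): it takes `model` and `template` as URL query parameters, a JSON
    # body of the form {"queries": ["...", ...]}, and returns a JSON list of
    # float vectors, one per input string.
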
    def encode(
        self,
        sentences: list[str],
        task_name: str | None = None,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        embedding_url = "http://127.0.0.1:5015/embedding"

        # Unlabelled inputs are embedded with the document template.
        if prompt_type is None or prompt_type == PromptType.document:
            template = "document"
        elif prompt_type == PromptType.query:
            template = "query"
        else:
            raise ValueError(f"Unsupported prompt_type: {prompt_type!r}")

        all_embeddings = self.get_simplexity_query2vec_results(
            sentences, embedding_url, self.model, template
        )
        # faiss indexes require float32 input.
        return np.asarray(all_embeddings, dtype=np.float32)

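# encode() mirrors mteb's encoder interface (sentences / task_name /
# prompt_type), so CustomModel can also be plugged into mteb task runners.
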
def build_faiss_index(embeddings: np.ndarray) -> faiss.Index:
    dim = embeddings.shape[1]
    # Inner product equals cosine similarity, assuming the service returns
    # L2-normalized embeddings.
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    return index

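# IndexFlatIP is exact brute-force search, which is fine at this corpus size.
# For much larger corpora an approximate index is a common swap, e.g.
# (hypothetical, untested here):
#   index = faiss.IndexHNSWFlat(dim, 32)  # 32 = HNSW graph connectivity
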
def recall_at_k(index, query_embeddings, ground_truth, all_docs, k=10):
    """
    ground_truth: the gold (positive) passage text for each query; a query
    counts as a hit if its gold passage appears among the top-k retrieved docs.
    """
    _scores, retrieved = index.search(query_embeddings, k)

    hits = 0
    for i, gt in enumerate(ground_truth):
        if gt in (all_docs[j] for j in retrieved[i]):
            hits += 1

    return hits / len(ground_truth)

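# Tiny sanity check with hypothetical 2-D embeddings (not part of the run):
#   docs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
#   idx = build_faiss_index(docs)
#   q = np.array([[1.0, 0.1]], dtype=np.float32)
#   recall_at_k(idx, q, ground_truth=["a"], all_docs=["a", "b"], k=1)  # -> 1.0
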
def evaluate():
    model_name = "Qwen3-Embedding-0.6B"
    # model_name = "KaLM-embedding-multilingual-mini-instruct-v2.5"
    # model_name = "KaLM-Embedding-Gemma3-12B-2511"
    # model_name = "llama-embed-nemotron-8b"
    # model_name = "embeddinggemma-300m"
    model = CustomModel(model_name)

    file_path = os.path.dirname(__file__)
    data_path = os.path.join(
        file_path, "..", "data", "dataset", "religous_test_50000", "test_religous.json"
    )
    with open(data_path, "r") as f:
        data = json.load(f)

    all_docs = [p for item in data for p in item["passage_positive"]]
    all_queries = [item["question"] for item in data]

    num_docs = len(all_docs)
    # One query per data item; len(all_docs) would over-run data/all_queries
    # whenever items carry multiple positive passages.
    num_test = len(all_queries)

    doc_embeddings = model.encode(all_docs[:num_docs], prompt_type=PromptType.document)
    index = build_faiss_index(doc_embeddings)

    query_embeddings = model.encode(all_queries[:num_test], prompt_type=PromptType.query)
    # The first positive passage is treated as the gold answer for each query.
    ground_truth = [data[i]["passage_positive"][0] for i in range(num_test)]
    for k in [1, 5, 10]:
        r = recall_at_k(index, query_embeddings, ground_truth, all_docs, k)
        print(f"Recall@{k}: {r:.6f}")

    # finetune:
    # Recall@1: 0.401740
    # Recall@5: 0.614240
    # Recall@10: 0.683000

def main():
    # get_results()
    evaluate()


if __name__ == "__main__":
    main()