From ca6548961f64b6a1592d48eb225175e859b6746d Mon Sep 17 00:00:00 2001
From: hediehloo
Date: Tue, 11 Nov 2025 07:57:00 +0000
Subject: [PATCH] add evaluation

---
 evaluation/evaluation.py | 127 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 127 insertions(+)
 create mode 100644 evaluation/evaluation.py

diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py
new file mode 100644
index 0000000..c4ff17d
--- /dev/null
+++ b/evaluation/evaluation.py
@@ -0,0 +1,127 @@
+"""Run Persian (fas) MTEB v2 benchmark tasks against a local embedding server.
+
+Assumes an embedding service at http://127.0.0.1:5000/embedding that accepts a
+JSON body of the form {"queries": [...]} plus "model" and "template" query
+parameters, and returns a JSON list of embedding vectors.
+"""
+import json
+import os
+
+import mteb
+import numpy as np
+import requests
+import tqdm
+from datasets.config import HF_DATASETS_CACHE
+from mteb.encoder_interface import PromptType
+
+
+class CustomModel:
+    """Minimal MTEB-compatible encoder that proxies to the embedding server."""
+
+    def __init__(self, model):
+        self.session = requests.Session()
+        self.model = model
+
+    def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
+        params = {"model": model, "template": template}
+        headers = {"accept": "application/json"}
+        data = {}
+
+        # Only show a progress bar for large inputs.
+        my_range = range if len(sentences) < 2000 else tqdm.trange
+
+        batch_size = 1024
+        vec = []
+        for i in my_range(0, len(sentences), batch_size):
+            stop_idx = min(i + batch_size, len(sentences))
+            data["queries"] = sentences[i:stop_idx]
+            response = self.session.post(
+                embedding_url,
+                headers=headers,
+                params=params,
+                data=json.dumps(data),
+                timeout=600,
+            )
+            response.raise_for_status()
+            vec += response.json()
+        return vec
+
+    def encode(
+        self,
+        sentences: list[str],
+        task_name: str,
+        prompt_type: PromptType | None = None,
+        **kwargs,
+    ) -> np.ndarray:
+        embedding_url = "http://127.0.0.1:5000/embedding"
+
+        # Queries and documents are embedded with different templates;
+        # default to the document template when no prompt type is given.
+        if prompt_type is None or prompt_type == PromptType.document:
+            template = "document"
+        elif prompt_type == PromptType.query:
+            template = "query"
+        else:
+            raise ValueError(f"Unsupported prompt_type: {prompt_type}")
+
+        all_embeddings = self.get_simplexity_query2vec_results(
+            sentences, embedding_url, self.model, template
+        )
+        return np.array(all_embeddings)
+
+
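+# Optional sanity check: a minimal smoke test against the embedding endpoint,
+# assuming the server described in the module docstring is already running
+# locally. Not required for the benchmark run itself.
+def smoke_test_embedding_server(model_name="llama-embed-nemotron-8b"):
+    encoder = CustomModel(model_name)
+    vecs = encoder.encode(
+        ["A short test sentence."],
+        task_name="smoke-test",
+        prompt_type=PromptType.query,
+    )
+    # One vector per input sentence; the dimensionality depends on the model.
+    assert vecs.shape[0] == 1
+    print("embedding dimension:", vecs.shape[1])
+
+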
+def is_dataset_cached(dataset_name):
+    """Return True if a dataset with this name is already in the local HF cache."""
+    dataset_dir_prefix = dataset_name.replace("/", "__")
+    return any(dataset_dir_prefix in folder for folder in os.listdir(HF_DATASETS_CACHE))
+
+
+def evaluate():
+    # Alternatives tried: "Qwen3-Embedding-0.6B", "embeddinggemma-300m".
+    model_name = "llama-embed-nemotron-8b"
+    model = CustomModel(model_name)
+
+    file_path = os.path.dirname(__file__)
+    fas_benchmark = mteb.get_benchmark("MTEB(fas, v2)")
+
+    # Restrict the run to a subset of tasks. To run the full benchmark minus
+    # known-problematic tasks, exclude "DigikalamagClassification",
+    # "DigikalamagClustering", "MIRACLReranking", and
+    # "PersianWebDocumentRetrieval" instead.
+    selected_tasks = ["ArguAna-Fa.v2", "SCIDOCS-Fa.v2"]
+    benchmarks = [
+        fas_benchmark[i]
+        for i in range(len(fas_benchmark))
+        if fas_benchmark[i].metadata_dict["name"] in selected_tasks
+    ]
+
+    evaluation = mteb.MTEB(tasks=benchmarks)
+    results = evaluation.run(model, output_folder=file_path + "/results/" + model_name)
+    print("results = " + str(results))
+
+
+def main():
+    evaluate()
+
+
+if __name__ == "__main__":
+    main()
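+
+# Usage note (assuming the embedding server is already up):
+#   python evaluation/evaluation.py
+# mteb writes per-task JSON result files under the output_folder given above.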