Compare commits
No commits in common. "0ead53ec05dfd6c8cd84329834ae15405bb2f0fa" and "0ec2c4d8ce6a3ad49e9442cb39e37c805d14a03b" have entirely different histories.
0ead53ec05...0ec2c4d8ce
@@ -1,146 +0,0 @@
import json
import os

import mteb
import numpy as np
import requests
import tqdm
from datasets import load_dataset
from datasets.config import HF_DATASETS_CACHE
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer

# from mteb.abstasks.task_metadata import TaskMetadata
# from mteb.models.models_protocols import EncoderProtocol
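
# NOTE: this script assumes an embedding service is already running locally at
# http://127.0.0.1:5000/embedding (started separately); embeddings are fetched
# over HTTP rather than by loading the model in-process.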


class CustomModel:
    """mteb-compatible wrapper that fetches embeddings from a local HTTP service."""

    def __init__(self, model):
        self.session = requests.Session()
        self.model = model

    def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
        params = {"model": model, "template": template}
        headers = {"accept": "application/json"}
        data = {}

        # Only show a progress bar for large jobs.
        if len(sentences) < 2000:
            my_range = range
        else:
            my_range = tqdm.trange

        batch_size = 1024
        vec = []
        for i in my_range(0, len(sentences), batch_size):
            start_idx = i
            stop_idx = min(i + batch_size, len(sentences))
            data["queries"] = sentences[start_idx:stop_idx]
            response = self.session.post(
                embedding_url,
                headers=headers,
                params=params,
                data=json.dumps(data),
                timeout=600,
            )
            response.raise_for_status()
            vec += response.json()
        return vec
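
    # Assumed request/response contract for the /embedding endpoint (the
    # service lives outside this repo, so the exact schema is inferred from
    # the calls above):
    #
    #   POST /embedding?model=llama-embed-nemotron-8b&template=query
    #   body:     {"queries": ["sentence one", "sentence two"]}
    #   response: [[0.12, -0.03, ...], [0.07, 0.41, ...]]  # one vector per query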

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        embedding_url = "http://127.0.0.1:5000/embedding"

        # Map mteb's prompt type onto the service's templates, defaulting to
        # the document template.
        if prompt_type is None:
            template = "document"
        elif prompt_type == PromptType.query:
            template = "query"
        elif prompt_type == PromptType.document:
            template = "document"
        else:
            raise ValueError(f"Unsupported prompt_type: {prompt_type}")

        # Earlier per-batch variant, kept for reference:
        # all_embeddings = []
        # all_texts = []
        # for batch in inputs:
        #     all_texts += batch["text"]
        #     embeddings = self.get_simplexity_query2vec_results(batch["text"], embedding_url, model, template)
        #     all_embeddings += embeddings
        all_embeddings = self.get_simplexity_query2vec_results(sentences, embedding_url, self.model, template)
        return np.array(all_embeddings)
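

# Minimal sketch of how mteb is expected to call this wrapper during a run
# (the task name below is illustrative only):
#
#   model = CustomModel("llama-embed-nemotron-8b")
#   vecs = model.encode(["first text", "second text"],
#                       task_name="ArguAna-Fa.v2",
#                       prompt_type=PromptType.query)
#   vecs.shape  # -> (2, embedding_dim)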


def is_dataset_cached(dataset_name):
    # Heuristic: assumes the datasets cache folder name embeds the repo id
    # with "/" replaced by "__".
    dataset_dir_prefix = dataset_name.replace("/", "__")
    return any(dataset_dir_prefix in folder for folder in os.listdir(HF_DATASETS_CACHE))
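
# Example: is_dataset_cached("mteb/arguana") looks for a cache folder whose
# name contains "mteb__arguana" under HF_DATASETS_CACHE. The "__" naming is an
# assumption about the cache layout and may differ across datasets versions.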


def evaluate():
    # model_name = "Qwen3-Embedding-0.6B"
    # model_name = "embeddinggemma-300m"
    model_name = "llama-embed-nemotron-8b"
    model = CustomModel(model_name)

    file_path = os.path.dirname(__file__)

    # Local-model alternatives, kept for reference:
    # model = mteb.get_model(model_name)
    # model = SentenceTransformer(model_name)
    # model.model_card_data.model_name = model_name
    # model.mteb_model_meta.name = model_name

    fas_benchmark = mteb.get_benchmark("MTEB(fas, v2)")
    # benchmark = mteb.get_benchmark("MTEB(eng, v2)")
    # tasks = mteb.get_tasks(tasks=["Banking77Classification"])

    # cache = mteb.cache.ResultCache(cache_path=file_path + "/.cache")

    # Optional pre-download of the benchmark datasets, kept for reference:
    # for i in range(len(fas_benchmark)):
    #     dataset_conf = fas_benchmark[i].metadata_dict["dataset"]
    #     # if is_dataset_cached(dataset_conf["path"]):
    #     #     continue
    #     dataset = load_dataset(dataset_conf["path"], revision=dataset_conf["revision"])

    # Earlier task selections, kept for reference:
    # benchmarks = [t for t in fas_benchmark if t.metadata_dict["name"] not in
    #               ["DigikalamagClassification", "DigikalamagClustering",
    #                "MIRACLReranking", "PersianWebDocumentRetrieval"]]
    # benchmarks = [t for t in fas_benchmark if t.metadata_dict["name"] in ["ArguAna-Fa.v2"]]
    benchmarks = [
        fas_benchmark[i]
        for i in range(len(fas_benchmark))
        if fas_benchmark[i].metadata_dict["name"] in ["ArguAna-Fa.v2", "SCIDOCS-Fa.v2"]
    ]

    evaluation = mteb.MTEB(tasks=benchmarks)
    results = evaluation.run(model, output_folder=file_path + "/results/" + model_name)

    # Per-task variant with error handling, kept for reference:
    # for benchmark in benchmarks:
    #     try:
    #         evaluation = mteb.MTEB(tasks=[benchmark])
    #         results = evaluation.run(model, output_folder=file_path + "/results/" + model_name)
    #     except Exception:
    #         print("________________________")
    #         print("Error : " + str(benchmark.metadata_dict["name"]))
    # results = mteb.evaluate(model, tasks=benchmark, cache=cache)

    print("results = " + str(results))
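

# Assumed usage: start the embedding service on 127.0.0.1:5000 first, then run
# this script directly, e.g.:
#
#   python run_mteb_fas.py   # hypothetical filename for this script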


def main():
    # get_results()
    evaluate()


if __name__ == "__main__":
    main()