diff --git a/evaluation/evaluate_with_religous_50000.py b/evaluation/evaluate_with_religous_50000.py
index 5451609..b2a65d3 100644
--- a/evaluation/evaluate_with_religous_50000.py
+++ b/evaluation/evaluate_with_religous_50000.py
@@ -20,31 +20,40 @@ import numpy as np
 
 
 class CustomModel:
-    def __init__(self, model):
+    def __init__(self):
         self.session = requests.Session()
-        self.model = model
 
-    def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
-        params = {}
-        params["model"] = model
-        params["template"] = template
+    def get_embedding(self, sentence, prompt_name):
+        embedding_url = "http://127.0.0.1:5010/embed"
         headers = {"accept": "application/json"}
+        headers["Content-Type"] = "application/json"
+        data = {}
+        data["inputs"] = sentence
+        data["normalize"] = True
+        data["prompt_name"] = prompt_name
+        data["truncate"] = False
+        data["truncation_direction"] = "Right"
+
+        response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)
+
+        return response.json()
+
+
+    def get_simplexity_query2vec_results(self, sentences, template):
         if len(sentences) < 2000:
             my_range = range
         else:
             my_range = tqdm.trange
 
-        batch_size = 1024
+        batch_size = 64
         vec = []
         for i in my_range(0, len(sentences), batch_size):
             start_idx = i
             stop_idx = min(i+batch_size, len(sentences))
-            data["queries"] = sentences[start_idx:stop_idx]
-            response = self.session.post(embedding_url, headers=headers, params=params, data=json.dumps(data), timeout=600)
-            new_vec = response.json()
+            new_vec = self.get_embedding(sentences[start_idx:stop_idx], template)
             vec += new_vec
 
         return vec
@@ -56,8 +65,6 @@ class CustomModel:
         prompt_type: PromptType | None = None,
         **kwargs,
     ) -> np.ndarray:
-
-        embedding_url = "http://127.0.0.1:5015/embedding"
 
         if prompt_type == None:
             template = "document"
@@ -75,7 +82,7 @@ class CustomModel:
         # embeddings = self.get_simplexity_query2vec_results(batch["text"], embedding_url, model, template)
         # all_embeddings += embeddings
 
-        all_embeddings = self.get_simplexity_query2vec_results(sentences, embedding_url, self.model, template)
+        all_embeddings = self.get_simplexity_query2vec_results(sentences, template)
 
         return numpy.array(all_embeddings)
 
@@ -102,12 +109,11 @@ def recall_at_k(index, query_embeddings, ground_truth, all_docs, k=10):
     return hits / len(ground_truth)
 
 def evaluate():
-    model_name = "Qwen3-Embedding-0.6B"
     # model_name = "KaLM-embedding-multilingual-mini-instruct-v2.5"
     # model_name = "KaLM-Embedding-Gemma3-12B-2511"
     # model_name = "llama-embed-nemotron-8b"
     # model_name = "embeddinggemma-300m"
-    model = CustomModel(model_name)
+    model = CustomModel()
 
     file_path = os.path.dirname(__file__)
 
diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py
index 449d139..43608c2 100644
--- a/evaluation/evaluation.py
+++ b/evaluation/evaluation.py
@@ -96,6 +96,8 @@ def is_dataset_cached(dataset_name):
 
 def evaluate():
     model = CustomModel()
+    model_name = "Qwen3-Embedding-0.6B"
+
     file_path = os.path.dirname(__file__)
     # model = mteb.get_model(model_name)
     # model = SentenceTransformer(model_name)
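To sanity-check the new client path without running the full evaluation, here is a minimal sketch that mirrors the request `get_embedding` now sends. The URL, payload fields, headers, and timeout are taken from the diff; everything else is an assumption: the standalone `embed` helper, the sample sentences, the `"query"` prompt name, and the expectation that the server behaves like a text-embeddings-inference-style `/embed` endpoint returning a JSON list of normalized vectors.

```python
# Sketch only, not part of the change. Assumes an embedding server is already
# listening on http://127.0.0.1:5010 and accepts the payload shape used in the PR.
import json

import numpy as np
import requests


def embed(sentences, prompt_name, url="http://127.0.0.1:5010/embed"):
    """Send one small batch the same way CustomModel.get_embedding does."""
    payload = {
        "inputs": sentences,
        "normalize": True,
        "prompt_name": prompt_name,   # "document" comes from the diff; "query" is assumed
        "truncate": False,
        "truncation_direction": "Right",
    }
    response = requests.post(
        url,
        headers={"accept": "application/json", "Content-Type": "application/json"},
        data=json.dumps(payload),
        timeout=600,
    )
    response.raise_for_status()
    return response.json()  # expected: list of embedding vectors


if __name__ == "__main__":
    docs = np.array(embed(["a short test document", "another test document"], "document"))
    query = np.array(embed(["a short test query"], "query"))
    # With normalize=True the vectors are unit length, so a dot product is cosine similarity.
    print("doc matrix:", docs.shape, "query:", query.shape)
    print("similarities:", (docs @ query[0]).round(4))
```

If the shapes and similarity scores look reasonable here, the batched `get_simplexity_query2vec_results` path should produce the same vectors, just in chunks of 64.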