bug fix in evaluation
This commit is contained in:
parent
c5e67e5a3c
commit
e9e25a7704
@ -20,31 +20,40 @@ import numpy as np
|
||||
|
||||
|
||||
class CustomModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.model = model
|
||||
|
||||
|
||||
def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
|
||||
params = {}
|
||||
params["model"] = model
|
||||
params["template"] = template
|
||||
def get_embedding(self, sentece, prompt_name):
    """Request embeddings from the local embedding service.

    Posts a JSON payload to ``http://127.0.0.1:5010/embed`` through the
    instance's shared ``requests`` session and returns the decoded JSON body.

    Args:
        sentece: Value sent as the ``inputs`` field of the payload (the
            batching caller passes a slice of its sentence list).
        prompt_name: Value sent as the ``prompt_name`` field of the payload.

    Returns:
        The parsed JSON body of the response.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
    """
    embedding_url = "http://127.0.0.1:5010/embed"
    headers = {
        "accept": "application/json",
        "Content-Type": "application/json",
    }
    data = {
        "inputs": sentece,
        "normalize": True,
        "prompt_name": prompt_name,
        "truncate": False,
        "truncation_direction": "Right",
    }

    response = self.session.post(
        embedding_url,
        headers=headers,
        data=json.dumps(data),
        timeout=600,
    )
    # Fail loudly on HTTP errors: otherwise an error payload would be
    # returned as if it were embeddings and appended to the caller's results.
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
def get_simplexity_query2vec_results(self, sentences, template):
|
||||
|
||||
if len(sentences) < 2000:
|
||||
my_range = range
|
||||
else:
|
||||
my_range = tqdm.trange
|
||||
|
||||
batch_size = 1024
|
||||
batch_size = 64
|
||||
vec = []
|
||||
for i in my_range(0, len(sentences), batch_size):
|
||||
start_idx = i
|
||||
stop_idx = min(i+batch_size, len(sentences))
|
||||
data["queries"] = sentences[start_idx:stop_idx]
|
||||
response = self.session.post(embedding_url, headers=headers, params=params, data=json.dumps(data), timeout=600)
|
||||
new_vec = response.json()
|
||||
new_vec = self.get_embedding(sentences[start_idx:stop_idx], template)
|
||||
vec += new_vec
|
||||
return vec
|
||||
|
||||
@ -56,8 +65,6 @@ class CustomModel:
|
||||
prompt_type: PromptType | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
|
||||
embedding_url = "http://127.0.0.1:5015/embedding"
|
||||
|
||||
if prompt_type == None:
|
||||
template = "document"
|
||||
@ -75,7 +82,7 @@ class CustomModel:
|
||||
# embeddings = self.get_simplexity_query2vec_results(batch["text"], embedding_url, model, template)
|
||||
|
||||
# all_embeddings += embeddings
|
||||
all_embeddings = self.get_simplexity_query2vec_results(sentences, embedding_url, self.model, template)
|
||||
all_embeddings = self.get_simplexity_query2vec_results(sentences, template)
|
||||
|
||||
|
||||
return numpy.array(all_embeddings)
|
||||
@ -102,12 +109,11 @@ def recall_at_k(index, query_embeddings, ground_truth, all_docs, k=10):
|
||||
return hits / len(ground_truth)
|
||||
|
||||
def evaluate():
|
||||
model_name = "Qwen3-Embedding-0.6B"
|
||||
# model_name = "KaLM-embedding-multilingual-mini-instruct-v2.5"
|
||||
# model_name = "KaLM-Embedding-Gemma3-12B-2511"
|
||||
# model_name = "llama-embed-nemotron-8b"
|
||||
# model_name = "embeddinggemma-300m"
|
||||
model = CustomModel(model_name)
|
||||
model = CustomModel()
|
||||
|
||||
file_path = os.path.dirname(__file__)
|
||||
|
||||
|
||||
@ -96,6 +96,8 @@ def is_dataset_cached(dataset_name):
|
||||
def evaluate():
|
||||
model = CustomModel()
|
||||
|
||||
model_name = "Qwen3-Embedding-0.6B"
|
||||
|
||||
file_path = os.path.dirname(__file__)
|
||||
# model = mteb.get_model(model_name)
|
||||
# model = SentenceTransformer(model_name)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user