add serve

This commit is contained in:
a.hediehloo 2025-12-28 09:07:48 +00:00
parent 59831896f1
commit c5e67e5a3c
3 changed files with 69 additions and 22 deletions


@@ -19,31 +19,40 @@ import numpy
 class CustomModel:
-    def __init__(self, model):
+    def __init__(self):
         self.session = requests.Session()
-        self.model = model
 
-    def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
-        params = {}
-        params["model"] = model
-        params["template"] = template
+    def get_embedding(self, sentence, prompt_name):
+        embedding_url = "http://127.0.0.1:5010/embed"
         headers = {"accept": "application/json"}
+        headers["Content-Type"] = "application/json"
         data = {}
+        data["inputs"] = sentence
+        data["normalize"] = True
+        data["prompt_name"] = prompt_name
+        data["truncate"] = False
+        data["truncation_direction"] = "Right"
+        response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)
+        return response.json()
+
+    def get_simplexity_query2vec_results(self, sentences, template):
         if len(sentences) < 2000:
             my_range = range
         else:
             my_range = tqdm.trange
-        batch_size = 1024
+        batch_size = 64
         vec = []
         for i in my_range(0, len(sentences), batch_size):
             start_idx = i
             stop_idx = min(i+batch_size, len(sentences))
-            data["queries"] = sentences[start_idx:stop_idx]
-            response = self.session.post(embedding_url, headers=headers, params=params, data=json.dumps(data), timeout=600)
-            new_vec = response.json()
+            new_vec = self.get_embedding(sentences[start_idx:stop_idx], template)
             vec += new_vec
         return vec
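
For reference, the new get_embedding method talks to text-embeddings-inference's /embed route. Below is a minimal standalone sketch of the same request, assuming the container from the compose file further down is listening on host port 5010; the route returns one embedding vector per input. This snippet is illustrative and not part of the commit.

import requests

# Standalone version of the request that get_embedding builds; the payload
# fields mirror the diff above. "inputs" accepts a string or a list of strings.
payload = {
    "inputs": ["what is a transformer?"],
    "normalize": True,
    "prompt_name": "query",            # must name a prompt known to the model
    "truncate": False,
    "truncation_direction": "Right",
}
resp = requests.post("http://127.0.0.1:5010/embed", json=payload, timeout=600)
resp.raise_for_status()
vectors = resp.json()                  # list of float lists, one per input
print(len(vectors), len(vectors[0]))
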
@@ -55,9 +64,6 @@ class CustomModel:
         prompt_type: PromptType | None = None,
         **kwargs,
     ) -> np.ndarray:
-        embedding_url = "http://127.0.0.1:5015/embedding"
         if prompt_type == None:
             template = "document"
         elif prompt_type == PromptType.query:
@@ -74,7 +80,7 @@ class CustomModel:
         # embeddings = self.get_simplexity_query2vec_results(batch["text"], embedding_url, model, template)
         # all_embeddings += embeddings
-        all_embeddings = self.get_simplexity_query2vec_results(sentences, embedding_url, self.model, template)
+        all_embeddings = self.get_simplexity_query2vec_results(sentences, template)
         return numpy.array(all_embeddings)
@@ -88,12 +94,7 @@ def is_dataset_cached(dataset_name):
 def evaluate():
-    model_name = "Qwen3-Embedding-0.6B"
-    # model_name = "KaLM-embedding-multilingual-mini-instruct-v2.5"
-    # model_name = "KaLM-Embedding-Gemma3-12B-2511"
-    # model_name = "llama-embed-nemotron-8b"
-    # model_name = "embeddinggemma-300m"
-    model = CustomModel(model_name)
+    model = CustomModel()
     file_path = os.path.dirname(__file__)
     # model = mteb.get_model(model_name)
@@ -140,7 +141,6 @@ def evaluate():
     print("results = " + str(results))
 
 def main():
-    # get_results()
     evaluate()


@@ -0,0 +1,27 @@
+version: "3.8"
+services:
+  query_2_vec_Qwen3-Embedding-0.6B:
+    image: ghcr.io/huggingface/text-embeddings-inference:1.8
+    container_name: query_2_vec_Qwen3-Embedding-0.6B
+    restart: unless-stopped
+    entrypoint: /bin/bash
+    command: |
+      -c "bash /app/start_vllm.sh"
+    shm_size: '45600m'
+    ports:
+      - "5010:8080"
+    volumes:
+      - .:/app
+      # - ./../../../../data:/app/data
+      - /home/hediehloo/code/embedding/embedding_model/train/qwen/output/v28-20251223-054407/merged-checkpoint-3707:/app/data/models/Qwen3-Embedding-0.6B/model
+      # - ./../../../logging_config.json:/app/logging_config.json
+      # - ./../../../src:/app/src
+      # - ./../../../logs:/app/logs
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - capabilities: [gpu]
+    stdin_open: true
+    tty: true
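
Before pointing the evaluation client at this container, it is worth waiting until the router is actually serving; a small polling sketch, assuming TEI exposes GET /health on the mapped host port (this helper is illustrative, not part of the commit):

import time
import requests

# Poll the (assumed) TEI /health route until the service answers, or give up.
def wait_for_tei(url="http://127.0.0.1:5010/health", attempts=60, delay=5.0):
    for _ in range(attempts):
        try:
            if requests.get(url, timeout=5).status_code == 200:
                return True
        except requests.ConnectionError:
            pass
        time.sleep(delay)
    return False

if not wait_for_tei():
    raise RuntimeError("embedding service did not become healthy")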

serve/qwen/start_vllm.sh (new file)

@@ -0,0 +1,20 @@
+export CUDA_VISIBLE_DEVICES=0
+export PYTHONPATH="/app"
+# export VLLM_LOGGING_CONFIG_PATH=/app/logging_config.json
+export TZ="Asia/Tehran"
+
+# mkdir -p /app/logs
+# sleep 200
+
+text-embeddings-router \
+    --model-id /app/data/models/Qwen3-Embedding-0.6B/model \
+    --port 8080 \
+    --dtype float16 \
+    --max-client-batch-size 1024 \
+    --max-concurrent-requests 1024 \
+    --max-batch-requests 1024 \
+    --max-batch-tokens 32768
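
Once the router is up (the script name notwithstanding, it launches the TEI router rather than vLLM), its configuration can be sanity-checked against the flags above; a sketch assuming TEI's GET /info route, where the exact response field names are assumptions and should be treated as illustrative:

import requests

# Query the (assumed) TEI /info route and compare against the launch flags.
info = requests.get("http://127.0.0.1:5010/info", timeout=10).json()
print(info.get("model_id"))                # /app/data/models/Qwen3-Embedding-0.6B/model
print(info.get("model_dtype"))             # float16
print(info.get("max_client_batch_size"))   # 1024
print(info.get("max_batch_tokens"))        # 32768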