From 6a90ff0a7ef29da78cb0df8d7bb5307b564d6412 Mon Sep 17 00:00:00 2001
From: SFirouzi
Date: Sat, 14 Mar 2026 11:23:42 +0330
Subject: [PATCH] add batch_size

---
 config/base.py                 |  4 +++-
 models/embedder_bge_train.py   | 20 ++++++++++++++++----
 models/embedder_gemma.py       | 21 ++++++++++++++++-----
 models/embedder_gemma_train.py | 20 ++++++++++++++++----
 src/serve_embed.py             | 28 ++++++++++++++++------------
 5 files changed, 67 insertions(+), 26 deletions(-)

diff --git a/config/base.py b/config/base.py
index aef7f84..8b3b161 100644
--- a/config/base.py
+++ b/config/base.py
@@ -6,4 +6,6 @@ load_dotenv()
 GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
 GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
 BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
-BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
\ No newline at end of file
+BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
+
+BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "250"))
\ No newline at end of file
diff --git a/models/embedder_bge_train.py b/models/embedder_bge_train.py
index e1dd03d..c139c8c 100644
--- a/models/embedder_bge_train.py
+++ b/models/embedder_bge_train.py
@@ -1,5 +1,7 @@
 from sentence_transformers import SentenceTransformer
 import requests
+from config.base import BATCH_SIZE
+import torch
 
 
 class TextEmbedderBgeTrain:
@@ -12,7 +14,17 @@ class TextEmbedderBgeTrain:
         """
         Embed texts using the model.
         """
-        if query:
-            return self.model.encode_query(texts)
-        else:
-            return self.model.encode_document(texts)
\ No newline at end of file
+        all_embeddings = []
+        for i in range(0, len(texts), BATCH_SIZE):
+            batch_texts = texts[i:i+BATCH_SIZE]
+
+            if query:
+                embeddings = self.model.encode_query(batch_texts)
+            else:
+                embeddings = self.model.encode_document(batch_texts)
+
+            all_embeddings.extend(embeddings)
+
+            torch.cuda.empty_cache()
+
+        return all_embeddings
\ No newline at end of file
diff --git a/models/embedder_gemma.py b/models/embedder_gemma.py
index 4015483..c47a404 100644
--- a/models/embedder_gemma.py
+++ b/models/embedder_gemma.py
@@ -1,6 +1,7 @@
 from sentence_transformers import SentenceTransformer
 import requests
-
+import torch
+from config.base import BATCH_SIZE
 
 
 class TextEmbedderGemma:
@@ -10,7 +11,17 @@ class TextEmbedderGemma:
         """
         Embed texts using the model.
         """
-        if query:
-            return self.model.encode_query(texts)
-        else:
-            return self.model.encode_document(texts)
\ No newline at end of file
+        all_embeddings = []
+        for i in range(0, len(texts), BATCH_SIZE):
+            batch_texts = texts[i:i+BATCH_SIZE]
+
+            if query:
+                embeddings = self.model.encode_query(batch_texts)
+            else:
+                embeddings = self.model.encode_document(batch_texts)
+
+            all_embeddings.extend(embeddings)
+
+            torch.cuda.empty_cache()
+
+        return all_embeddings
\ No newline at end of file
diff --git a/models/embedder_gemma_train.py b/models/embedder_gemma_train.py
index 44e8675..8b35c82 100644
--- a/models/embedder_gemma_train.py
+++ b/models/embedder_gemma_train.py
@@ -1,5 +1,7 @@
 from sentence_transformers import SentenceTransformer
 import requests
+import torch
+from config.base import BATCH_SIZE
 
 
 class TextEmbedderGemmaTrain:
@@ -12,7 +14,17 @@ class TextEmbedderGemmaTrain:
         """
         Embed texts using the model.
         """
-        if query:
-            return self.model.encode_query(texts)
-        else:
-            return self.model.encode_document(texts)
\ No newline at end of file
+        all_embeddings = []
+        for i in range(0, len(texts), BATCH_SIZE):
+            batch_texts = texts[i:i+BATCH_SIZE]
+
+            if query:
+                embeddings = self.model.encode_query(batch_texts)
+            else:
+                embeddings = self.model.encode_document(batch_texts)
+
+            all_embeddings.extend(embeddings)
+
+            torch.cuda.empty_cache()
+
+        return all_embeddings
\ No newline at end of file
diff --git a/src/serve_embed.py b/src/serve_embed.py
index f388b53..6caa30c 100644
--- a/src/serve_embed.py
+++ b/src/serve_embed.py
@@ -39,17 +39,23 @@ def embed_gemma(request: EmbedRequest):
     Returns:
         data: list[dict]
     """
-    texts = request.input if isinstance(request.input, list) else [request.input]
-
-    if request.model == "gemma":
-        embeddings = embed_gemma.embed_texts(texts, request.query)
-
-    elif request.model == "gemma_train":
-        embeddings = embed_gemma_train.embed_texts(texts, request.query)
+    try:
+        texts = request.input if isinstance(request.input, list) else [request.input]
+
+        if request.model == "gemma":
+            embeddings = embed_gemma.embed_texts(texts, request.query)
+
+        elif request.model == "gemma_train":
+            embeddings = embed_gemma_train.embed_texts(texts, request.query)
 
-    elif request.model == "bge_train":
-        embeddings = embed_bge_train.embed_texts(texts, request.query)
-
-    else:
-        raise HTTPException(status_code=400, detail="Invalid model")
+        elif request.model == "bge_train":
+            embeddings = embed_bge_train.embed_texts(texts, request.query)
+
+        else:
+            raise HTTPException(status_code=400, detail="Invalid model")
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+    torch.cuda.empty_cache()  # NOTE(review): confirm this module imports torch
     return {"data": [{"embedding": emb.tolist()} for emb in embeddings]}
\ No newline at end of file