add batch_size

SFirouzi 2026-03-14 11:23:42 +03:30
parent 3a38099749
commit 6a90ff0a7e
5 changed files with 67 additions and 26 deletions

View File

@@ -6,4 +6,6 @@ load_dotenv()
 GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
 GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
 BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
-BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
+BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
+BATCH_SIZE = 250

View File

@@ -1,5 +1,7 @@
 from sentence_transformers import SentenceTransformer
 import requests
+from config.base import BATCH_SIZE
+import torch
 class TextEmbedderBgeTrain:
@@ -12,7 +14,17 @@ class TextEmbedderBgeTrain:
         """
         Embed texts using the model.
         """
-        if query:
-            return self.model.encode_query(texts)
-        else:
-            return self.model.encode_document(texts)
+        all_embeddings = []
+        for i in range(0, len(texts), BATCH_SIZE):
+            batch_texts = texts[i:i+BATCH_SIZE]
+            if query:
+                embeddings = self.model.encode_query(batch_texts)
+            else:
+                embeddings = self.model.encode_document(batch_texts)
+            all_embeddings.extend(embeddings)
+            torch.cuda.empty_cache()
+        return all_embeddings
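
For reference, the change in this file (and repeated in the two Gemma embedders below) amounts to chunking the input list by BATCH_SIZE, encoding each chunk, and releasing cached GPU memory between chunks so a long input list does not spike peak usage. A minimal standalone sketch of that idea, independent of the repository's classes (function name and the BATCH_SIZE value here are illustrative, not from the repo):

from sentence_transformers import SentenceTransformer
import torch

BATCH_SIZE = 250  # illustrative; the commit reads this from config.base

def embed_in_batches(model: SentenceTransformer, texts: list[str], query: bool = False):
    """Encode texts in fixed-size chunks to bound peak GPU memory."""
    all_embeddings = []
    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i:i + BATCH_SIZE]
        # encode_query/encode_document exist in sentence-transformers >= 5.0;
        # on older versions model.encode(batch) would be used instead
        emb = model.encode_query(batch) if query else model.encode_document(batch)
        all_embeddings.extend(emb)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # drop cached allocator blocks between chunks
    return all_embeddings

Each encode call still batches internally; the outer loop only caps how many texts are handed to the model at once.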

View File

@@ -1,6 +1,7 @@
 from sentence_transformers import SentenceTransformer
 import requests
 import torch
+from config.base import BATCH_SIZE
 class TextEmbedderGemma:
     def __init__(self, model_path):
@@ -10,7 +11,17 @@ class TextEmbedderGemma:
         """
         Embed texts using the model.
         """
-        if query:
-            return self.model.encode_query(texts)
-        else:
-            return self.model.encode_document(texts)
+        all_embeddings = []
+        for i in range(0, len(texts), BATCH_SIZE):
+            batch_texts = texts[i:i+BATCH_SIZE]
+            if query:
+                embeddings = self.model.encode_query(batch_texts)
+            else:
+                embeddings = self.model.encode_document(batch_texts)
+            all_embeddings.extend(embeddings)
+            torch.cuda.empty_cache()
+        return all_embeddings

View File

@@ -1,5 +1,7 @@
 from sentence_transformers import SentenceTransformer
 import requests
+import torch
+from config.base import BATCH_SIZE
 class TextEmbedderGemmaTrain:
@@ -12,7 +14,17 @@ class TextEmbedderGemmaTrain:
         """
         Embed texts using the model.
         """
-        if query:
-            return self.model.encode_query(texts)
-        else:
-            return self.model.encode_document(texts)
+        all_embeddings = []
+        for i in range(0, len(texts), BATCH_SIZE):
+            batch_texts = texts[i:i+BATCH_SIZE]
+            if query:
+                embeddings = self.model.encode_query(batch_texts)
+            else:
+                embeddings = self.model.encode_document(batch_texts)
+            all_embeddings.extend(embeddings)
+            torch.cuda.empty_cache()
+        return all_embeddings

View File

@@ -39,20 +39,24 @@ def embed_gemma(request: EmbedRequest):
     Returns:
         data: list[dict]
     """
-    texts = request.input if isinstance(request.input, list) else [request.input]
-    if request.model == "gemma":
-        embeddings = embed_gemma.embed_texts(texts, request.query)
-    elif request.model == "gemma_train":
-        embeddings = embed_gemma_train.embed_texts(texts, request.query)
-    elif request.model == "bge_train":
-        embeddings = embed_bge_train.embed_texts(texts, request.query)
-    else:
-        raise HTTPException(status_code=400, detail="Invalid model")
+    try:
+        texts = request.input if isinstance(request.input, list) else [request.input]
+        if request.model == "gemma":
+            embeddings = embed_gemma.embed_texts(texts, request.query)
+        elif request.model == "gemma_train":
+            embeddings = embed_gemma_train.embed_texts(texts, request.query)
+        elif request.model == "bge_train":
+            embeddings = embed_bge_train.embed_texts(texts, request.query)
+        else:
+            raise HTTPException(status_code=400, detail="Invalid model")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
     torch.cuda.empty_cache()
     return {"data": [{"embedding": emb.tolist()} for emb in embeddings]}