Add BATCH_SIZE config constant and embed texts in batches
This commit is contained in:
parent
3a38099749
commit
6a90ff0a7e
@ -6,4 +6,6 @@ load_dotenv()
|
||||
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
|
||||
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
||||
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
|
||||
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
|
||||
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
|
||||
|
||||
BATCH_SIZE = 250
|
||||
@ -1,5 +1,7 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
from config.base import BATCH_SIZE
|
||||
import torch
|
||||
|
||||
|
||||
class TextEmbedderBgeTrain:
|
||||
@ -12,7 +14,17 @@ class TextEmbedderBgeTrain:
|
||||
"""
|
||||
Embed texts using the model.
|
||||
"""
|
||||
if query:
|
||||
return self.model.encode_query(texts)
|
||||
else:
|
||||
return self.model.encode_document(texts)
|
||||
all_embeddings = []
|
||||
for i in range(0, len(texts), BATCH_SIZE):
|
||||
batch_texts = texts[i:i+BATCH_SIZE]
|
||||
|
||||
if query:
|
||||
embeddings = self.model.encode_query(batch_texts)
|
||||
else:
|
||||
embeddings = self.model.encode_document(batch_texts)
|
||||
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return all_embeddings
|
||||
@ -1,6 +1,7 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
|
||||
import torch
|
||||
from config.base import BATCH_SIZE
|
||||
|
||||
class TextEmbedderGemma:
|
||||
def __init__(self, model_path):
|
||||
@ -10,7 +11,17 @@ class TextEmbedderGemma:
|
||||
"""
|
||||
Embed texts using the model.
|
||||
"""
|
||||
if query:
|
||||
return self.model.encode_query(texts)
|
||||
else:
|
||||
return self.model.encode_document(texts)
|
||||
all_embeddings = []
|
||||
for i in range(0, len(texts), BATCH_SIZE):
|
||||
batch_texts = texts[i:i+BATCH_SIZE]
|
||||
|
||||
if query:
|
||||
embeddings = self.model.encode_query(batch_texts)
|
||||
else:
|
||||
embeddings = self.model.encode_document(batch_texts)
|
||||
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return all_embeddings
|
||||
@ -1,5 +1,7 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
import torch
|
||||
from config.base import BATCH_SIZE
|
||||
|
||||
|
||||
class TextEmbedderGemmaTrain:
|
||||
@ -12,7 +14,17 @@ class TextEmbedderGemmaTrain:
|
||||
"""
|
||||
Embed texts using the model.
|
||||
"""
|
||||
if query:
|
||||
return self.model.encode_query(texts)
|
||||
else:
|
||||
return self.model.encode_document(texts)
|
||||
all_embeddings = []
|
||||
for i in range(0, len(texts), BATCH_SIZE):
|
||||
batch_texts = texts[i:i+BATCH_SIZE]
|
||||
|
||||
if query:
|
||||
embeddings = self.model.encode_query(batch_texts)
|
||||
else:
|
||||
embeddings = self.model.encode_document(batch_texts)
|
||||
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return all_embeddings
|
||||
@ -39,20 +39,24 @@ def embed_gemma(request: EmbedRequest):
|
||||
Returns:
|
||||
data: list[dict]
|
||||
"""
|
||||
texts = request.input if isinstance(request.input, list) else [request.input]
|
||||
|
||||
if request.model == "gemma":
|
||||
embeddings = embed_gemma.embed_texts(texts, request.query)
|
||||
|
||||
elif request.model == "gemma_train":
|
||||
embeddings = embed_gemma_train.embed_texts(texts, request.query)
|
||||
try:
|
||||
texts = request.input if isinstance(request.input, list) else [request.input]
|
||||
|
||||
if request.model == "gemma":
|
||||
embeddings = embed_gemma.embed_texts(texts, request.query)
|
||||
|
||||
elif request.model == "gemma_train":
|
||||
embeddings = embed_gemma_train.embed_texts(texts, request.query)
|
||||
|
||||
elif request.model == "bge_train":
|
||||
embeddings = embed_bge_train.embed_texts(texts, request.query)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid model")
|
||||
elif request.model == "bge_train":
|
||||
embeddings = embed_bge_train.embed_texts(texts, request.query)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid model")
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {"data": [{"embedding": emb.tolist()} for emb in embeddings]}
|
||||
Loading…
x
Reference in New Issue
Block a user