add batch_size
This commit is contained in:
parent
3a38099749
commit
6a90ff0a7e
@ -7,3 +7,5 @@ GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
|
|||||||
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
||||||
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
|
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
|
||||||
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
|
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
|
||||||
|
|
||||||
|
BATCH_SIZE = 250
|
||||||
@ -1,5 +1,7 @@
|
|||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import requests
|
import requests
|
||||||
|
from config.base import BATCH_SIZE
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TextEmbedderBgeTrain:
|
class TextEmbedderBgeTrain:
|
||||||
@ -12,7 +14,17 @@ class TextEmbedderBgeTrain:
|
|||||||
"""
|
"""
|
||||||
Embed texts using the model.
|
Embed texts using the model.
|
||||||
"""
|
"""
|
||||||
if query:
|
all_embeddings = []
|
||||||
return self.model.encode_query(texts)
|
for i in range(0, len(texts), BATCH_SIZE):
|
||||||
else:
|
batch_texts = texts[i:i+BATCH_SIZE]
|
||||||
return self.model.encode_document(texts)
|
|
||||||
|
if query:
|
||||||
|
embeddings = self.model.encode_query(batch_texts)
|
||||||
|
else:
|
||||||
|
embeddings = self.model.encode_document(batch_texts)
|
||||||
|
|
||||||
|
all_embeddings.extend(embeddings)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return all_embeddings
|
||||||
@ -1,6 +1,7 @@
|
|||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
from config.base import BATCH_SIZE
|
||||||
|
|
||||||
class TextEmbedderGemma:
|
class TextEmbedderGemma:
|
||||||
def __init__(self, model_path):
|
def __init__(self, model_path):
|
||||||
@ -10,7 +11,17 @@ class TextEmbedderGemma:
|
|||||||
"""
|
"""
|
||||||
Embed texts using the model.
|
Embed texts using the model.
|
||||||
"""
|
"""
|
||||||
if query:
|
all_embeddings = []
|
||||||
return self.model.encode_query(texts)
|
for i in range(0, len(texts), BATCH_SIZE):
|
||||||
else:
|
batch_texts = texts[i:i+BATCH_SIZE]
|
||||||
return self.model.encode_document(texts)
|
|
||||||
|
if query:
|
||||||
|
embeddings = self.model.encode_query(batch_texts)
|
||||||
|
else:
|
||||||
|
embeddings = self.model.encode_document(batch_texts)
|
||||||
|
|
||||||
|
all_embeddings.extend(embeddings)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return all_embeddings
|
||||||
@ -1,5 +1,7 @@
|
|||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
from config.base import BATCH_SIZE
|
||||||
|
|
||||||
|
|
||||||
class TextEmbedderGemmaTrain:
|
class TextEmbedderGemmaTrain:
|
||||||
@ -12,7 +14,17 @@ class TextEmbedderGemmaTrain:
|
|||||||
"""
|
"""
|
||||||
Embed texts using the model.
|
Embed texts using the model.
|
||||||
"""
|
"""
|
||||||
if query:
|
all_embeddings = []
|
||||||
return self.model.encode_query(texts)
|
for i in range(0, len(texts), BATCH_SIZE):
|
||||||
else:
|
batch_texts = texts[i:i+BATCH_SIZE]
|
||||||
return self.model.encode_document(texts)
|
|
||||||
|
if query:
|
||||||
|
embeddings = self.model.encode_query(batch_texts)
|
||||||
|
else:
|
||||||
|
embeddings = self.model.encode_document(batch_texts)
|
||||||
|
|
||||||
|
all_embeddings.extend(embeddings)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return all_embeddings
|
||||||
@ -39,19 +39,23 @@ def embed_gemma(request: EmbedRequest):
|
|||||||
Returns:
|
Returns:
|
||||||
data: list[dict]
|
data: list[dict]
|
||||||
"""
|
"""
|
||||||
texts = request.input if isinstance(request.input, list) else [request.input]
|
try:
|
||||||
|
texts = request.input if isinstance(request.input, list) else [request.input]
|
||||||
|
|
||||||
if request.model == "gemma":
|
if request.model == "gemma":
|
||||||
embeddings = embed_gemma.embed_texts(texts, request.query)
|
embeddings = embed_gemma.embed_texts(texts, request.query)
|
||||||
|
|
||||||
elif request.model == "gemma_train":
|
elif request.model == "gemma_train":
|
||||||
embeddings = embed_gemma_train.embed_texts(texts, request.query)
|
embeddings = embed_gemma_train.embed_texts(texts, request.query)
|
||||||
|
|
||||||
elif request.model == "bge_train":
|
elif request.model == "bge_train":
|
||||||
embeddings = embed_bge_train.embed_texts(texts, request.query)
|
embeddings = embed_bge_train.embed_texts(texts, request.query)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail="Invalid model")
|
raise HTTPException(status_code=400, detail="Invalid model")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user