This commit is contained in:
SFirouzi 2026-03-12 13:40:10 +03:30
parent 3dd659fb7e
commit 8eebc192e0
6 changed files with 49 additions and 20 deletions

View File

@ -4,4 +4,6 @@ from dotenv import load_dotenv
load_dotenv()
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")

View File

@ -1,6 +0,0 @@
def main() -> None:
    """Entry-point stub for the serve-embed package."""
    greeting = "Hello from serve-embed!"
    print(greeting)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,18 @@
from sentence_transformers import SentenceTransformer
import requests
class TextEmbedderBgeTrain:
    """LoRA-finetuned BGE sentence embedder loaded from local files onto the first CUDA device."""

    def __init__(self, model_path, lora_path):
        # Load the base BGE model strictly from disk, then attach the trained LoRA adapter.
        base = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True)
        self.model = base.to(device="cuda:0")
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str], query: bool = False) -> list[list[float]]:
        """Embed *texts*; encode as search queries when *query* is True, otherwise as documents."""
        encode = self.model.encode_query if query else self.model.encode_document
        return encode(texts)

View File

@ -6,8 +6,11 @@ class TextEmbedderGemma:
def __init__(self, model_path):
    # Load the Gemma embedding model strictly from local files and move it to the first GPU.
    gemma = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True)
    self.model = gemma.to(device="cuda:0")
def embed_texts(self, texts: list[str], query: bool = False) -> list[list[float]]:
    """Embed texts using the model.

    Args:
        texts: Batch of input strings to embed.
        query: When True, encode the texts as search queries;
            otherwise encode them as documents/passages.

    Returns:
        One embedding vector per input text.
    """
    # The stripped diff left both the old and new versions of this method
    # interleaved (duplicate signatures plus a dead `encode` call); only the
    # new query-aware implementation is kept.
    if query:
        return self.model.encode_query(texts)
    return self.model.encode_document(texts)

View File

@ -8,8 +8,11 @@ class TextEmbedderGemmaTrain:
self.model.load_adapter(lora_path)
def embed_texts(self, texts: list[str], query: bool = False) -> list[list[float]]:
    """Embed texts using the LoRA-adapted model.

    Args:
        texts: Batch of input strings to embed.
        query: When True, encode the texts as search queries;
            otherwise encode them as documents/passages.

    Returns:
        One embedding vector per input text.
    """
    # The stripped diff left the old and new versions interleaved (duplicate
    # signatures plus the dead `encode` call); only the new query-aware
    # implementation is kept.
    if query:
        return self.model.encode_query(texts)
    return self.model.encode_document(texts)

View File

@ -5,40 +5,49 @@ from pydantic import BaseModel
from models.embedder_gemma import TextEmbedderGemma
from models.embedder_gemma_train import TextEmbedderGemmaTrain
from models.embedder_bge_train import TextEmbedderBgeTrain
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH, BGE_MODEL_PATH, BGE_LORA_PATH
app = FastAPI()
@app.on_event("startup")
def load_models():
    """Instantiate every embedding model once at application startup.

    The models are exposed as module-level globals (`embed_gemma`,
    `embed_gemma_train`, `embed_bge_train`) consumed by the embed endpoint.
    """
    # The stripped diff left both the old two-model globals and the new
    # three-model globals interleaved; only the new version is kept.
    global embed_gemma, embed_gemma_train, embed_bge_train
    embed_gemma = TextEmbedderGemma(GEMMA_MODEL_PATH)
    embed_gemma_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
    embed_bge_train = TextEmbedderBgeTrain(BGE_MODEL_PATH, BGE_LORA_PATH)
    print("Models loaded successfully")
class EmbedRequest(BaseModel):
    # Request payload for the embed endpoint.
    model: str = "gemma"  # which embedder to use: "gemma", "gemma_train", or "bge_train"
    input: list[str] | str  # a single text or a batch of texts to embed
    query: bool = False  # True when the text is a search query rather than a passage/document
@app.post("/embed_texts")
def embed_gemma(request: EmbedRequest):
"""
Embed texts using the model.
Args:
request: EmbedRequest
*model : can be from these models : ["gemma", "gemma_train", "bge_train"]
*input : it is a list of texts
*query : your text can be query or passage. if it is query set this to true.
Returns:
data: list[dict]
"""
texts = request.input if isinstance(request.input, list) else [request.input]
if request.model == "gemma":
embeddings = embed_gemma.embed_texts(texts, request.query)
elif request.model == "gemma_train":
embeddings = embed_gemma_train.embed_texts(texts, request.query)
elif request.model == "bge_train":
embeddings = embed_bge_train.embed_texts(texts, request.query)
else:
raise HTTPException(status_code=400, detail="Invalid model")