This commit is contained in:
SFirouzi 2026-03-12 13:40:10 +03:30
parent 3dd659fb7e
commit 8eebc192e0
6 changed files with 49 additions and 20 deletions

View File

@ -5,3 +5,5 @@ load_dotenv()
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH") GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH") GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")

View File

@ -1,6 +0,0 @@
def main():
    """Print the serve-embed greeting to standard output."""
    greeting = "Hello from serve-embed!"
    print(greeting)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()

View File

@ -0,0 +1,18 @@
from sentence_transformers import SentenceTransformer
import requests
class TextEmbedderBgeTrain:
    """Text embedder built from a base model plus an adapter loaded on top.

    The base model is read from local files only (``local_files_only=True``),
    moved to the requested device, and then the adapter found at
    ``lora_path`` is attached with ``load_adapter``.
    """

    def __init__(self, model_path, lora_path, device="cuda:0"):
        """Load the base model and attach the adapter.

        Args:
            model_path: Path of the base model handed to ``SentenceTransformer``.
            lora_path: Path of the adapter handed to ``load_adapter``.
            device: Device the model is moved to. Defaults to ``"cuda:0"``,
                the value that was previously hard-coded, so existing
                two-argument callers behave exactly as before.
        """
        self.model = SentenceTransformer(
            model_path,
            trust_remote_code=True,
            local_files_only=True,
        ).to(device=device)
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str], query: bool = False) -> list[list[float]]:
        """Embed ``texts`` as queries or as documents.

        Args:
            texts: Texts to embed.
            query: When true the texts go through ``encode_query``;
                otherwise through ``encode_document``.

        Returns:
            Whatever the selected encoder returns, unchanged.
            NOTE(review): the annotation promises ``list[list[float]]`` but
            the encoder's actual return type is not visible here - confirm
            against the library before relying on list semantics.
        """
        # Pick the encoder once instead of duplicating the call in two branches.
        encode = self.model.encode_query if query else self.model.encode_document
        return encode(texts)

View File

@ -6,8 +6,11 @@ class TextEmbedderGemma:
def __init__(self, model_path): def __init__(self, model_path):
self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0") self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0")
def embed_texts(self, texts:list[str])->list[list[float]]: def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
""" """
Embed texts using the model. Embed texts using the model.
""" """
return self.model.encode(texts) if query:
return self.model.encode_query(texts)
else:
return self.model.encode_document(texts)

View File

@ -8,8 +8,11 @@ class TextEmbedderGemmaTrain:
self.model.load_adapter(lora_path) self.model.load_adapter(lora_path)
def embed_texts(self, texts:list[str])->list[list[float]]: def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
""" """
Embed texts using the model. Embed texts using the model.
""" """
return self.model.encode(texts) if query:
return self.model.encode_query(texts)
else:
return self.model.encode_document(texts)

View File

@ -5,40 +5,49 @@ from pydantic import BaseModel
from models.embedder_gemma import TextEmbedderGemma from models.embedder_gemma import TextEmbedderGemma
from models.embedder_gemma_train import TextEmbedderGemmaTrain from models.embedder_gemma_train import TextEmbedderGemmaTrain
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH from models.embedder_bge_train import TextEmbedderBgeTrain
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH, BGE_MODEL_PATH, BGE_LORA_PATH
app = FastAPI() app = FastAPI()
@app.on_event("startup") @app.on_event("startup")
def load_models(): def load_models():
global embedder, embedder_train global embed_gemma, embed_gemma_train, embed_bge_train
embedder = TextEmbedderGemma(GEMMA_MODEL_PATH) embed_gemma = TextEmbedderGemma(GEMMA_MODEL_PATH)
embedder_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH) embed_gemma_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
embed_bge_train = TextEmbedderBgeTrain(BGE_MODEL_PATH, BGE_LORA_PATH)
print("Models loaded successfully") print("Models loaded successfully")
class EmbedRequest(BaseModel): class EmbedRequest(BaseModel):
model: str = "gemma" model: str = "gemma"
input: list[str] | str input: list[str] | str
query: bool = False
@app.post("/embed_gemma") @app.post("/embed_texts")
def embed_gemma(request: EmbedRequest): def embed_gemma(request: EmbedRequest):
""" """
Embed texts using the model Embed texts using the model.
Args: Args:
request: EmbedRequest *model : can be from these models : ["gemma", "gemma_train", "bge_train"]
*input : it is a list of texts
*query : your text can be query or passage. if it is query set this to true.
Returns: Returns:
data: list[dict] data: list[dict]
""" """
texts = request.input if isinstance(request.input, list) else [request.input] texts = request.input if isinstance(request.input, list) else [request.input]
if request.model == "gemma": if request.model == "gemma":
embeddings = embedder.embed_texts(texts) embeddings = embed_gemma.embed_texts(texts, request.query)
elif request.model == "gemma_train": elif request.model == "gemma_train":
embeddings = embedder_train.embed_texts(texts) embeddings = embed_gemma_train.embed_texts(texts, request.query)
elif request.model == "bge_train":
embeddings = embed_bge_train.embed_texts(texts, request.query)
else: else:
raise HTTPException(status_code=400, detail="Invalid model") raise HTTPException(status_code=400, detail="Invalid model")