From 8eebc192e052343caae38b93b609f3d2062d99ec Mon Sep 17 00:00:00 2001
From: SFirouzi
Date: Thu, 12 Mar 2026 13:40:10 +0330
Subject: [PATCH] add bge

---
 config/base.py                 |  4 +++-
 main.py                        |  6 ------
 models/embedder_bge_train.py   | 18 ++++++++++++++++++
 models/embedder_gemma.py       |  7 +++++--
 models/embedder_gemma_train.py |  7 +++++--
 src/serve_embed.py             | 27 ++++++++++++++++++---------
 6 files changed, 49 insertions(+), 20 deletions(-)
 delete mode 100644 main.py
 create mode 100644 models/embedder_bge_train.py

diff --git a/config/base.py b/config/base.py
index 5fa663e..aef7f84 100644
--- a/config/base.py
+++ b/config/base.py
@@ -4,4 +4,6 @@ from dotenv import load_dotenv
 load_dotenv()
 
 GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
-GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
\ No newline at end of file
+GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
+BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
+BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
\ No newline at end of file
diff --git a/main.py b/main.py
deleted file mode 100644
index e24b934..0000000
--- a/main.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def main():
-    print("Hello from serve-embed!")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/models/embedder_bge_train.py b/models/embedder_bge_train.py
new file mode 100644
index 0000000..e1dd03d
--- /dev/null
+++ b/models/embedder_bge_train.py
@@ -0,0 +1,18 @@
+from sentence_transformers import SentenceTransformer
+import requests
+
+
+class TextEmbedderBgeTrain:
+    def __init__(self, model_path, lora_path):
+        self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0")
+        self.model.load_adapter(lora_path)
+
+
+    def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
+        """
+        Embed texts using the model.
+        """
+        if query:
+            return self.model.encode_query(texts)
+        else:
+            return self.model.encode_document(texts)
\ No newline at end of file
diff --git a/models/embedder_gemma.py b/models/embedder_gemma.py
index 35fa995..4015483 100644
--- a/models/embedder_gemma.py
+++ b/models/embedder_gemma.py
@@ -6,8 +6,11 @@ class TextEmbedderGemma:
     def __init__(self, model_path):
         self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0")
 
-    def embed_texts(self, texts:list[str])->list[list[float]]:
+    def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
         """
         Embed texts using the model.
         """
-        return self.model.encode(texts)
\ No newline at end of file
+        if query:
+            return self.model.encode_query(texts)
+        else:
+            return self.model.encode_document(texts)
\ No newline at end of file
diff --git a/models/embedder_gemma_train.py b/models/embedder_gemma_train.py
index a6a0a41..44e8675 100644
--- a/models/embedder_gemma_train.py
+++ b/models/embedder_gemma_train.py
@@ -8,8 +8,11 @@
         self.model.load_adapter(lora_path)
 
 
-    def embed_texts(self, texts:list[str])->list[list[float]]:
+    def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
         """
         Embed texts using the model.
         """
-        return self.model.encode(texts)
\ No newline at end of file
+        if query:
+            return self.model.encode_query(texts)
+        else:
+            return self.model.encode_document(texts)
\ No newline at end of file
diff --git a/src/serve_embed.py b/src/serve_embed.py
index 4c2b452..5173147 100644
--- a/src/serve_embed.py
+++ b/src/serve_embed.py
@@ -5,40 +5,49 @@ from pydantic import BaseModel
 from models.embedder_gemma import TextEmbedderGemma
 from models.embedder_gemma_train import TextEmbedderGemmaTrain
-from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH
+from models.embedder_bge_train import TextEmbedderBgeTrain
+from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH, BGE_MODEL_PATH, BGE_LORA_PATH
 
 
 app = FastAPI()
 
 
 @app.on_event("startup")
 def load_models():
-    global embedder, embedder_train
-    embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
-    embedder_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
+    global embed_gemma, embed_gemma_train, embed_bge_train
+    embed_gemma = TextEmbedderGemma(GEMMA_MODEL_PATH)
+    embed_gemma_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
+    embed_bge_train = TextEmbedderBgeTrain(BGE_MODEL_PATH, BGE_LORA_PATH)
     print("Models loaded successfully")
 
 
 class EmbedRequest(BaseModel):
     model: str = "gemma"
     input: list[str] | str
+    query: bool = False
 
 
-@app.post("/embed_gemma")
+@app.post("/embed_texts")
 def embed_gemma(request: EmbedRequest):
     """
-    Embed texts using the model
+    Embed texts using the model.
+
     Args:
-        request: EmbedRequest
+        *model : can be from these models : ["gemma", "gemma_train", "bge_train"]
+        *input : it is a list of texts
+        *query : your text can be query or passage. if it is query set this to true.
     Returns:
         data: list[dict]
     """
     texts = request.input if isinstance(request.input, list) else [request.input]
     if request.model == "gemma":
-        embeddings = embedder.embed_texts(texts)
+        embeddings = embed_gemma.embed_texts(texts, request.query)
 
     elif request.model == "gemma_train":
-        embeddings = embedder_train.embed_texts(texts)
+        embeddings = embed_gemma_train.embed_texts(texts, request.query)
+
+    elif request.model == "bge_train":
+        embeddings = embed_bge_train.embed_texts(texts, request.query)
 
     else:
         raise HTTPException(status_code=400, detail="Invalid model")