add bge
This commit is contained in:
parent
3dd659fb7e
commit
8eebc192e0
@@ -4,4 +4,6 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
|
||||
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
||||
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
||||
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
|
||||
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
|
||||
6
main.py
6
main.py
@@ -1,6 +0,0 @@
|
||||
def main():
|
||||
print("Hello from serve-embed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
18
models/embedder_bge_train.py
Normal file
18
models/embedder_bge_train.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
|
||||
|
||||
class TextEmbedderBgeTrain:
    """Sentence embedder for a local BGE model with a LoRA adapter applied.

    Loads the base model from local files only and moves it to cuda:0, then
    attaches the fine-tuned LoRA adapter on top of it.
    """

    def __init__(self, model_path, lora_path):
        """Load the base model onto cuda:0 and attach the LoRA adapter.

        Args:
            model_path: local path to the base SentenceTransformer model.
            lora_path: local path to the LoRA adapter weights.
        """
        base = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True)
        self.model = base.to(device="cuda:0")
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str], query: bool = False) -> list[list[float]]:
        """
        Embed texts using the model.

        Args:
            texts: texts to embed.
            query: True to encode as search queries, False (default) to
                encode as documents/passages.

        Returns:
            One embedding vector per input text.
        """
        # Queries and documents go through different encoding entry points.
        encode = self.model.encode_query if query else self.model.encode_document
        return encode(texts)
|
||||
@@ -6,8 +6,11 @@ class TextEmbedderGemma:
|
||||
def __init__(self, model_path):
|
||||
self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0")
|
||||
|
||||
def embed_texts(self, texts:list[str])->list[list[float]]:
|
||||
def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
|
||||
"""
|
||||
Embed texts using the model.
|
||||
"""
|
||||
return self.model.encode(texts)
|
||||
if query:
|
||||
return self.model.encode_query(texts)
|
||||
else:
|
||||
return self.model.encode_document(texts)
|
||||
@@ -8,8 +8,11 @@ class TextEmbedderGemmaTrain:
|
||||
self.model.load_adapter(lora_path)
|
||||
|
||||
|
||||
def embed_texts(self, texts:list[str])->list[list[float]]:
|
||||
def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
|
||||
"""
|
||||
Embed texts using the model.
|
||||
"""
|
||||
return self.model.encode(texts)
|
||||
if query:
|
||||
return self.model.encode_query(texts)
|
||||
else:
|
||||
return self.model.encode_document(texts)
|
||||
@@ -5,40 +5,49 @@ from pydantic import BaseModel
|
||||
|
||||
from models.embedder_gemma import TextEmbedderGemma
|
||||
from models.embedder_gemma_train import TextEmbedderGemmaTrain
|
||||
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH
|
||||
from models.embedder_bge_train import TextEmbedderBgeTrain
|
||||
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH, BGE_MODEL_PATH, BGE_LORA_PATH
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.on_event("startup")
|
||||
def load_models():
|
||||
global embedder, embedder_train
|
||||
embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
|
||||
embedder_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
|
||||
global embed_gemma, embed_gemma_train, embed_bge_train
|
||||
embed_gemma = TextEmbedderGemma(GEMMA_MODEL_PATH)
|
||||
embed_gemma_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
|
||||
embed_bge_train = TextEmbedderBgeTrain(BGE_MODEL_PATH, BGE_LORA_PATH)
|
||||
print("Models loaded successfully")
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
|
||||
model: str = "gemma"
|
||||
input: list[str] | str
|
||||
query: bool = False
|
||||
|
||||
|
||||
@app.post("/embed_gemma")
|
||||
@app.post("/embed_texts")
|
||||
def embed_gemma(request: EmbedRequest):
|
||||
"""
|
||||
Embed texts using the model
|
||||
Embed texts using the model.
|
||||
|
||||
Args:
|
||||
request: EmbedRequest
|
||||
*model : can be from these models : ["gemma", "gemma_train", "bge_train"]
|
||||
*input : it is a list of texts
|
||||
*query : your text can be query or passage. if it is query set this to true.
|
||||
Returns:
|
||||
data: list[dict]
|
||||
"""
|
||||
texts = request.input if isinstance(request.input, list) else [request.input]
|
||||
|
||||
if request.model == "gemma":
|
||||
embeddings = embedder.embed_texts(texts)
|
||||
embeddings = embed_gemma.embed_texts(texts, request.query)
|
||||
|
||||
elif request.model == "gemma_train":
|
||||
embeddings = embedder_train.embed_texts(texts)
|
||||
embeddings = embed_gemma_train.embed_texts(texts, request.query)
|
||||
|
||||
elif request.model == "bge_train":
|
||||
embeddings = embed_bge_train.embed_texts(texts, request.query)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid model")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user