add bge
This commit is contained in:
parent
3dd659fb7e
commit
8eebc192e0
@ -4,4 +4,6 @@ from dotenv import load_dotenv
|
|||||||
# Load environment variables from a local .env file into os.environ.
load_dotenv()

# Local filesystem paths for the embedding models, supplied via environment
# variables. os.getenv returns None when a variable is unset — downstream
# model loading will fail in that case; no validation happens here.
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH")
BGE_LORA_PATH = os.getenv("BGE_LORA_PATH")
|
||||||
6
main.py
6
main.py
@ -1,6 +0,0 @@
|
|||||||
def main():
|
|
||||||
print("Hello from serve-embed!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
18
models/embedder_bge_train.py
Normal file
18
models/embedder_bge_train.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedderBgeTrain:
    """BGE sentence-embedding model with a trained LoRA adapter applied.

    Mirrors the interface of the Gemma embedders so the serving layer can
    dispatch to any of them uniformly via ``embed_texts``.
    """

    def __init__(self, model_path, lora_path, device="cuda:0"):
        """Load the base model from a local path and attach the LoRA adapter.

        Args:
            model_path: Local directory of the base SentenceTransformer model.
            lora_path: Local directory of the trained LoRA adapter.
            device: Torch device string to place the model on; defaults to
                the first CUDA GPU (previous behavior, now parameterized).
        """
        # local_files_only=True guarantees no network fetch at startup.
        self.model = SentenceTransformer(
            model_path, trust_remote_code=True, local_files_only=True
        ).to(device=device)
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str], query: bool = False) -> list[list[float]]:
        """Embed texts using the model.

        Args:
            texts: Input strings to embed.
            query: True to encode the texts as search queries,
                False (default) to encode them as documents/passages.

        Returns:
            One embedding vector per input text.
        """
        if query:
            return self.model.encode_query(texts)
        return self.model.encode_document(texts)
@ -6,8 +6,11 @@ class TextEmbedderGemma:
|
|||||||
def __init__(self, model_path):
    """Load the Gemma SentenceTransformer from a local path onto the first GPU.

    Args:
        model_path: Local directory containing the model files.
    """
    # Load strictly from disk (no hub download), then move to cuda:0.
    base = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True)
    self.model = base.to(device="cuda:0")
def embed_texts(self, texts:list[str])->list[list[float]]:
|
def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Embed texts using the model.
|
Embed texts using the model.
|
||||||
"""
|
"""
|
||||||
return self.model.encode(texts)
|
if query:
|
||||||
|
return self.model.encode_query(texts)
|
||||||
|
else:
|
||||||
|
return self.model.encode_document(texts)
|
||||||
@ -8,8 +8,11 @@ class TextEmbedderGemmaTrain:
|
|||||||
self.model.load_adapter(lora_path)
|
self.model.load_adapter(lora_path)
|
||||||
|
|
||||||
|
|
||||||
def embed_texts(self, texts:list[str])->list[list[float]]:
|
def embed_texts(self, texts:list[str], query:bool = False)->list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Embed texts using the model.
|
Embed texts using the model.
|
||||||
"""
|
"""
|
||||||
return self.model.encode(texts)
|
if query:
|
||||||
|
return self.model.encode_query(texts)
|
||||||
|
else:
|
||||||
|
return self.model.encode_document(texts)
|
||||||
@ -5,40 +5,49 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from models.embedder_gemma import TextEmbedderGemma
|
from models.embedder_gemma import TextEmbedderGemma
|
||||||
from models.embedder_gemma_train import TextEmbedderGemmaTrain
|
from models.embedder_gemma_train import TextEmbedderGemmaTrain
|
||||||
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH
|
from models.embedder_bge_train import TextEmbedderBgeTrain
|
||||||
|
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH, BGE_MODEL_PATH, BGE_LORA_PATH
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def load_models():
|
def load_models():
|
||||||
global embedder, embedder_train
|
global embed_gemma, embed_gemma_train, embed_bge_train
|
||||||
embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
|
embed_gemma = TextEmbedderGemma(GEMMA_MODEL_PATH)
|
||||||
embedder_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
|
embed_gemma_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
|
||||||
|
embed_bge_train = TextEmbedderBgeTrain(BGE_MODEL_PATH, BGE_LORA_PATH)
|
||||||
print("Models loaded successfully")
|
print("Models loaded successfully")
|
||||||
|
|
||||||
|
|
||||||
class EmbedRequest(BaseModel):
    """Request body for the embedding endpoint."""

    # Which embedder to use; the endpoint accepts "gemma", "gemma_train",
    # or "bge_train" and rejects anything else with HTTP 400.
    model: str = "gemma"
    # Text(s) to embed; a bare string is wrapped into a one-element list
    # by the endpoint before encoding.
    input: list[str] | str
    # True → encode as a search query; False (default) → encode as a
    # document/passage.
    query: bool = False
|
|
||||||
|
|
||||||
@app.post("/embed_gemma")
|
@app.post("/embed_texts")
|
||||||
def embed_gemma(request: EmbedRequest):
|
def embed_gemma(request: EmbedRequest):
|
||||||
"""
|
"""
|
||||||
Embed texts using the model
|
Embed texts using the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: EmbedRequest
|
*model : can be from these models : ["gemma", "gemma_train", "bge_train"]
|
||||||
|
*input : it is a list of texts
|
||||||
|
*query : your text can be query or passage. if it is query set this to true.
|
||||||
Returns:
|
Returns:
|
||||||
data: list[dict]
|
data: list[dict]
|
||||||
"""
|
"""
|
||||||
texts = request.input if isinstance(request.input, list) else [request.input]
|
texts = request.input if isinstance(request.input, list) else [request.input]
|
||||||
|
|
||||||
if request.model == "gemma":
|
if request.model == "gemma":
|
||||||
embeddings = embedder.embed_texts(texts)
|
embeddings = embed_gemma.embed_texts(texts, request.query)
|
||||||
|
|
||||||
elif request.model == "gemma_train":
|
elif request.model == "gemma_train":
|
||||||
embeddings = embedder_train.embed_texts(texts)
|
embeddings = embed_gemma_train.embed_texts(texts, request.query)
|
||||||
|
|
||||||
|
elif request.model == "bge_train":
|
||||||
|
embeddings = embed_bge_train.embed_texts(texts, request.query)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail="Invalid model")
|
raise HTTPException(status_code=400, detail="Invalid model")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user