add all
This commit is contained in:
commit
3dd659fb7e
1
.dockerignore
Normal file
1
.dockerignore
Normal file
@ -0,0 +1 @@
|
|||||||
|
.venv/*
|
||||||
2
.env.example
Normal file
2
.env.example
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
GEMMA_MODEL_PATH="./checkpoints/gemma/snapshots/57c266a740f537b4dc058e1b0cda161fd15afa75"
|
||||||
|
GEMMA_LORA_PATH="./checkpoints/gemma_lora"
|
||||||
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
__pycache__
|
||||||
|
.venv
|
||||||
|
.env
|
||||||
|
checkpoints
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
23
Dockerfile
Normal file
23
Dockerfile
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# Container image for the embedding service (gunicorn+uvicorn on :3037).
# Slim Python 3.12 base image.
FROM python:3.12-slim

# Bring the uv/uvx binaries in from the official astral-sh image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Non-interactive apt, unbuffered stdout for container logs, no .pyc files.
ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Copy only the dependency manifests first so the `uv sync` layer is cached
# across source-code changes.
COPY pyproject.toml uv.lock /app/

# Install dependencies
RUN uv sync --no-dev

# Add virtual environment to PATH
ENV PATH="/app/.venv/bin:$PATH"

# Copy the application source after dependency install (better layer caching).
COPY . /app

EXPOSE 3037

# Single uvicorn worker behind gunicorn, bound to the exposed port.
CMD ["python", "-m", "gunicorn", "src.serve_embed:app", "--workers", "1", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:3037", "--timeout", "20"]
|
||||||
6
README.md
Normal file
6
README.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
## SERVE EMBEDDING MODEL
|
||||||
|
|
||||||
|
## MODELS
|
||||||
|
|
||||||
|
1-GEMMA-300M model
|
||||||
|
2-GEMMA-300M model fine-tuned
|
||||||
7
config/base.py
Normal file
7
config/base.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# Centralized configuration: reads model checkpoint paths from the environment.
import os

from dotenv import load_dotenv

# Pull variables from a local .env file (if present) into the process environment.
load_dotenv()

# NOTE(review): both resolve to None when the variable is unset — confirm
# downstream consumers handle a missing path.
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
||||||
6
main.py
Normal file
6
main.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
def main() -> None:
    """Entry point: print a greeting to stdout."""
    greeting = "Hello from serve-embed!"
    print(greeting)


if __name__ == "__main__":
    main()
|
||||||
13
models/embedder_gemma.py
Normal file
13
models/embedder_gemma.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedderGemma:
    """Wrapper around a locally stored Gemma SentenceTransformer checkpoint."""

    def __init__(self, model_path, device="cuda:0"):
        """Load the checkpoint from disk and move it to *device*.

        Args:
            model_path: Filesystem path of the checkpoint; loaded with
                local_files_only=True, so no network fetch occurs.
            device: Torch device string. Defaults to "cuda:0" to preserve the
                previous hard-coded behavior; pass "cpu" for CPU-only hosts.
        """
        self.model = SentenceTransformer(
            model_path, trust_remote_code=True, local_files_only=True
        ).to(device=device)

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed texts using the model.

        Note: SentenceTransformer.encode returns an array-like whose rows
        support .tolist(); callers rely on that, so the encode output is
        returned unchanged rather than converted to plain lists.
        """
        return self.model.encode(texts)
|
||||||
15
models/embedder_gemma_train.py
Normal file
15
models/embedder_gemma_train.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedderGemmaTrain:
    """Gemma SentenceTransformer checkpoint with a LoRA adapter applied."""

    def __init__(self, model_path, lora_path, device="cuda:0"):
        """Load the base checkpoint, move it to *device*, and attach the adapter.

        Args:
            model_path: Filesystem path of the base checkpoint; loaded with
                local_files_only=True, so no network fetch occurs.
            lora_path: Filesystem path of the LoRA adapter weights.
            device: Torch device string. Defaults to "cuda:0" to preserve the
                previous hard-coded behavior; pass "cpu" for CPU-only hosts.
        """
        self.model = SentenceTransformer(
            model_path, trust_remote_code=True, local_files_only=True
        ).to(device=device)
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed texts using the model.

        Note: SentenceTransformer.encode returns an array-like whose rows
        support .tolist(); callers rely on that, so the encode output is
        returned unchanged rather than converted to plain lists.
        """
        return self.model.encode(texts)
|
||||||
16
pyproject.toml
Normal file
16
pyproject.toml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
[project]
|
||||||
|
name = "serve-embed"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.135.1",
|
||||||
|
"gunicorn==23.0.0",
|
||||||
|
"peft>=0.18.1",
|
||||||
|
"python-dotenv==1.2.1",
|
||||||
|
"requests>=2.32.5",
|
||||||
|
"sentence-transformers>=5.2.3",
|
||||||
|
"transformers==4.57.3",
|
||||||
|
"uvicorn==0.40.0",
|
||||||
|
]
|
||||||
46
src/serve_embed.py
Normal file
46
src/serve_embed.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import uvicorn
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from models.embedder_gemma import TextEmbedderGemma
|
||||||
|
from models.embedder_gemma_train import TextEmbedderGemmaTrain
|
||||||
|
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
# favor of lifespan handlers — it still works, but consider migrating.
@app.on_event("startup")
def load_models():
    """Load both embedder variants once at startup.

    The instances are published as module-level globals so that the request
    handler (embed_gemma) can reach them without re-loading per request.
    """
    global embedder, embedder_train
    # Base Gemma checkpoint.
    embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
    # Same base checkpoint plus the LoRA adapter.
    embedder_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
    print("Models loaded successfully")
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedRequest(BaseModel):
    """Request body for POST /embed_gemma."""

    # Which embedder to use: "gemma" (base) or "gemma_train" (LoRA fine-tuned);
    # any other value yields a 400 from the handler.
    model: str = "gemma"
    # A single text or a batch of texts to embed.
    input: list[str] | str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/embed_gemma")
def embed_gemma(request: EmbedRequest):
    """
    Embed texts using the model

    Args:
        request: EmbedRequest

    Returns:
        data: list[dict]
    """
    # Normalize the payload to a batch: a lone string becomes a one-item list.
    if isinstance(request.input, list):
        texts = request.input
    else:
        texts = [request.input]

    # Route to the embedder selected by the request; unknown names are a 400.
    if request.model == "gemma":
        vectors = embedder.embed_texts(texts)
    elif request.model == "gemma_train":
        vectors = embedder_train.embed_texts(texts)
    else:
        raise HTTPException(status_code=400, detail="Invalid model")

    data = []
    for vec in vectors:
        data.append({"embedding": vec.tolist()})
    return {"data": data}
|
||||||
Loading…
x
Reference in New Issue
Block a user