Commit 3dd659fb7e by SFirouzi, 2026-03-07 16:51:27 +03:30
13 changed files with 1401 additions and 0 deletions

1
.dockerignore Normal file

@@ -0,0 +1 @@
.venv/*

2
.env.example Normal file

@@ -0,0 +1,2 @@
GEMMA_MODEL_PATH="./checkpoints/gemma/snapshots/57c266a740f537b4dc058e1b0cda161fd15afa75"
GEMMA_LORA_PATH="./checkpoints/gemma_lora"

4
.gitignore vendored Normal file

@@ -0,0 +1,4 @@
__pycache__
.venv
.env
checkpoints

1
.python-version Normal file

@@ -0,0 +1 @@
3.12

23
Dockerfile Normal file

@@ -0,0 +1,23 @@
FROM python:3.12-slim

# Bring in the uv binary from the official distribution image
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Copy only the dependency manifests first so this layer stays cached
COPY pyproject.toml uv.lock /app/

# Install dependencies
RUN uv sync --no-dev

# Add virtual environment to PATH
ENV PATH="/app/.venv/bin:$PATH"

COPY . /app

EXPOSE 3037

CMD ["python", "-m", "gunicorn", "src.serve_embed:app", "--workers", "1", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:3037", "--timeout", "20"]

6
README.md Normal file

@@ -0,0 +1,6 @@
## SERVE EMBEDDING MODEL

## MODELS

1. GEMMA-300M model
2. GEMMA-300M model, fine-tuned (LoRA adapter)

7
config/base.py Normal file

@@ -0,0 +1,7 @@
import os

from dotenv import load_dotenv

load_dotenv()
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
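`os.getenv` returns `None` when a variable is unset, so a missing `.env` entry only surfaces later as an opaque failure inside `SentenceTransformer`. A minimal fail-fast sketch, assuming import-time validation is wanted (not part of this commit):

```python
# Hypothetical fail-fast check (not in this commit): raise immediately if a
# required variable is missing rather than passing None to the model loader.
for _name, _value in {
    "GEMMA_MODEL_PATH": GEMMA_MODEL_PATH,
    "GEMMA_LORA_PATH": GEMMA_LORA_PATH,
}.items():
    if not _value:
        raise RuntimeError(f"Missing required environment variable: {_name}")
```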

6
main.py Normal file

@@ -0,0 +1,6 @@
def main():
    print("Hello from serve-embed!")


if __name__ == "__main__":
    main()

13
models/embedder_gemma.py Normal file

@@ -0,0 +1,13 @@
from sentence_transformers import SentenceTransformer


class TextEmbedderGemma:
    def __init__(self, model_path):
        # Load the local checkpoint onto the first GPU
        self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0")

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """
        Embed texts using the model.
        """
        return self.model.encode(texts)
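A minimal usage sketch, assuming a CUDA GPU and the local checkpoint path from `.env.example`; note that `encode` returns a numpy array, which is why the server later calls `.tolist()` on each row:

```python
# Hypothetical usage; assumes a CUDA device and a local checkpoint on disk.
from models.embedder_gemma import TextEmbedderGemma
from config.base import GEMMA_MODEL_PATH

embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
vectors = embedder.embed_texts(["first text", "second text"])
print(vectors.shape)  # (2, embedding_dim): encode returns a numpy array
```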

15
models/embedder_gemma_train.py Normal file

@@ -0,0 +1,15 @@
from sentence_transformers import SentenceTransformer


class TextEmbedderGemmaTrain:
    def __init__(self, model_path, lora_path):
        # Load the base checkpoint, then attach the fine-tuned LoRA adapter
        self.model = SentenceTransformer(model_path, trust_remote_code=True, local_files_only=True).to(device="cuda:0")
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """
        Embed texts using the model.
        """
        return self.model.encode(texts)
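`load_adapter` relies on sentence-transformers' PEFT integration, which is why `peft` appears in the dependency list. A hedged sketch comparing the two variants on the same input; it loads the base checkpoint twice, exactly as the server does at startup:

```python
# Hypothetical side-by-side of base vs. LoRA-adapted embeddings. Loads two
# model copies onto the GPU, mirroring what serve_embed does at startup.
from models.embedder_gemma import TextEmbedderGemma
from models.embedder_gemma_train import TextEmbedderGemmaTrain
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH

base = TextEmbedderGemma(GEMMA_MODEL_PATH)
tuned = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
text = ["the same input text"]
print(base.embed_texts(text)[0][:5])   # first five dims, base model
print(tuned.embed_texts(text)[0][:5])  # first five dims, with LoRA adapter
```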

16
pyproject.toml Normal file

@@ -0,0 +1,16 @@
[project]
name = "serve-embed"
version = "0.1.0"
description = "FastAPI service for serving Gemma embedding models (base and LoRA fine-tuned)"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "fastapi>=0.135.1",
    "gunicorn==23.0.0",
    "peft>=0.18.1",
    "python-dotenv==1.2.1",
    "requests>=2.32.5",
    "sentence-transformers>=5.2.3",
    "transformers==4.57.3",
    "uvicorn==0.40.0",
]

46
src/serve_embed.py Normal file

@@ -0,0 +1,46 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from models.embedder_gemma import TextEmbedderGemma
from models.embedder_gemma_train import TextEmbedderGemmaTrain
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH

app = FastAPI()


@app.on_event("startup")
def load_models():
    # Load both embedders once per worker so every request reuses them
    global embedder, embedder_train
    embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
    embedder_train = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
    print("Models loaded successfully")


class EmbedRequest(BaseModel):
    model: str = "gemma"
    input: list[str] | str


@app.post("/embed_gemma")
def embed_gemma(request: EmbedRequest):
    """
    Embed texts using the selected model.

    Args:
        request: EmbedRequest

    Returns:
        data: list[dict]
    """
    # Accept a single string or a list of strings
    texts = request.input if isinstance(request.input, list) else [request.input]
    if request.model == "gemma":
        embeddings = embedder.embed_texts(texts)
    elif request.model == "gemma_train":
        embeddings = embedder_train.embed_texts(texts)
    else:
        raise HTTPException(status_code=400, detail="Invalid model")
    return {"data": [{"embedding": emb.tolist()} for emb in embeddings]}

1261
uv.lock generated Normal file

File diff suppressed because it is too large