add all
This commit is contained in:
commit
3dd659fb7e
1
.dockerignore
Normal file
1
.dockerignore
Normal file
@ -0,0 +1 @@
|
||||
.venv/*
|
||||
2
.env.example
Normal file
2
.env.example
Normal file
@ -0,0 +1,2 @@
|
||||
GEMMA_MODEL_PATH="./checkpoints/gemma/snapshots/57c266a740f537b4dc058e1b0cda161fd15afa75"
|
||||
GEMMA_LORA_PATH="./checkpoints/gemma_lora"
|
||||
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
__pycache__
|
||||
.venv
|
||||
.env
|
||||
checkpoints
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.12
|
||||
23
Dockerfile
Normal file
23
Dockerfile
Normal file
@ -0,0 +1,23 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
|
||||
# Install dependencies
|
||||
RUN uv sync --no-dev
|
||||
|
||||
# Add virtual environment to PATH
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
COPY . /app
|
||||
|
||||
EXPOSE 3037
|
||||
|
||||
CMD ["python", "-m", "gunicorn", "src.serve_embed:app", "--workers", "1", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:3037", "--timeout", "20"]
|
||||
6
README.md
Normal file
6
README.md
Normal file
@ -0,0 +1,6 @@
|
||||
## SERVE EMBEDDING MODEL
|
||||
|
||||
## MODELS
|
||||
|
||||
1-GEMMA-300M model
|
||||
2-GEMMA-300M model fine-tuned
|
||||
7
config/base.py
Normal file
7
config/base.py
Normal file
@ -0,0 +1,7 @@
|
||||
import os

from dotenv import load_dotenv


# Load variables from a local .env file into the process environment so the
# os.getenv calls below can see them (see .env.example for the expected keys).
load_dotenv()

# Filesystem path to the base Gemma checkpoint directory.
# NOTE(review): os.getenv returns None when the variable is unset — the
# consumers (SentenceTransformer loaders) would then fail; confirm the env
# is always populated in deployment.
GEMMA_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH")
# Filesystem path to the fine-tuned Gemma LoRA adapter directory.
GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH")
|
||||
6
main.py
Normal file
6
main.py
Normal file
@ -0,0 +1,6 @@
|
||||
def main() -> None:
    """Entry point: emit the project greeting on stdout."""
    greeting = "Hello from serve-embed!"
    print(greeting)


if __name__ == "__main__":
    main()
|
||||
13
models/embedder_gemma.py
Normal file
13
models/embedder_gemma.py
Normal file
@ -0,0 +1,13 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
|
||||
|
||||
class TextEmbedderGemma:
    """Embedder backed by a locally stored SentenceTransformer Gemma checkpoint.

    NOTE(review): the device is hard-coded to "cuda:0" and the checkpoint is
    loaded with trust_remote_code=True — confirm both are intended for every
    deployment target.
    """

    def __init__(self, model_path):
        loaded = SentenceTransformer(
            model_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        self.model = loaded.to(device="cuda:0")

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return embeddings for *texts* via SentenceTransformer.encode.

        NOTE(review): encode() returns a numpy array rather than the annotated
        list of lists; callers in src/serve_embed.py rely on the elements
        having .tolist() — confirm before tightening the annotation.
        """
        return self.model.encode(texts)
|
||||
15
models/embedder_gemma_train.py
Normal file
15
models/embedder_gemma_train.py
Normal file
@ -0,0 +1,15 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
|
||||
|
||||
class TextEmbedderGemmaTrain:
    """Gemma embedder with a LoRA adapter applied on top of the base checkpoint.

    NOTE(review): like TextEmbedderGemma, the device is hard-coded to
    "cuda:0" and trust_remote_code=True is passed — confirm intent.
    """

    def __init__(self, model_path, lora_path):
        base = SentenceTransformer(
            model_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        self.model = base.to(device="cuda:0")
        # Overlay the fine-tuned LoRA weights onto the base model.
        self.model.load_adapter(lora_path)

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return embeddings for *texts* via SentenceTransformer.encode.

        NOTE(review): encode() returns a numpy array despite the annotation;
        downstream code calls .tolist() on each row.
        """
        return self.model.encode(texts)
|
||||
16
pyproject.toml
Normal file
16
pyproject.toml
Normal file
@ -0,0 +1,16 @@
|
||||
[project]
|
||||
name = "serve-embed"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.135.1",
|
||||
"gunicorn==23.0.0",
|
||||
"peft>=0.18.1",
|
||||
"python-dotenv==1.2.1",
|
||||
"requests>=2.32.5",
|
||||
"sentence-transformers>=5.2.3",
|
||||
"transformers==4.57.3",
|
||||
"uvicorn==0.40.0",
|
||||
]
|
||||
46
src/serve_embed.py
Normal file
46
src/serve_embed.py
Normal file
@ -0,0 +1,46 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
import uvicorn
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.embedder_gemma import TextEmbedderGemma
|
||||
from models.embedder_gemma_train import TextEmbedderGemmaTrain
|
||||
from config.base import GEMMA_MODEL_PATH, GEMMA_LORA_PATH
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.on_event("startup")
def load_models():
    """Build both embedder singletons at startup and publish them module-wide.

    NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    favor of the lifespan API — confirm the pinned FastAPI version before
    migrating.
    """
    global embedder, embedder_train
    base_embedder = TextEmbedderGemma(GEMMA_MODEL_PATH)
    tuned_embedder = TextEmbedderGemmaTrain(GEMMA_MODEL_PATH, GEMMA_LORA_PATH)
    embedder = base_embedder
    embedder_train = tuned_embedder
    print("Models loaded successfully")
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
    """Request body for the /embed_gemma endpoint."""

    # Which embedder serves the request: "gemma" (base checkpoint) or
    # "gemma_train" (LoRA fine-tuned); any other value yields HTTP 400.
    model: str = "gemma"
    # A single text or a batch of texts to embed.
    input: list[str] | str
|
||||
|
||||
|
||||
@app.post("/embed_gemma")
def embed_gemma(request: EmbedRequest):
    """Embed the request texts with the selected Gemma variant.

    Args:
        request: EmbedRequest carrying the model name and the input text(s).

    Returns:
        dict with a "data" list holding one {"embedding": [...]} per text.

    Raises:
        HTTPException: 400 when request.model names an unknown variant.
    """
    # Reject unknown model names up front with a guard clause.
    if request.model not in ("gemma", "gemma_train"):
        raise HTTPException(status_code=400, detail="Invalid model")

    # Normalise a single string into a one-element batch.
    payload = request.input
    texts = payload if isinstance(payload, list) else [payload]

    if request.model == "gemma":
        vectors = embedder.embed_texts(texts)
    else:
        vectors = embedder_train.embed_texts(texts)

    data = [{"embedding": vec.tolist()} for vec in vectors]
    return {"data": data}
|
||||
Loading…
x
Reference in New Issue
Block a user