diff --git a/src/serve_embed.py b/src/serve_embed.py index 5173147..f388b53 100644 --- a/src/serve_embed.py +++ b/src/serve_embed.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from fastapi import HTTPException import uvicorn from pydantic import BaseModel +import torch from models.embedder_gemma import TextEmbedderGemma from models.embedder_gemma_train import TextEmbedderGemmaTrain @@ -52,4 +53,6 @@ def embed_gemma(request: EmbedRequest): else: raise HTTPException(status_code=400, detail="Invalid model") + torch.cuda.empty_cache() + return {"data": [{"embedding": emb.tolist()} for emb in embeddings]} \ No newline at end of file