diff --git a/config/base.py b/config/base.py index 8b3b161..ee0e2f9 100644 --- a/config/base.py +++ b/config/base.py @@ -8,4 +8,4 @@ GEMMA_LORA_PATH = os.getenv("GEMMA_LORA_PATH") BGE_MODEL_PATH = os.getenv("BGE_MODEL_PATH") BGE_LORA_PATH = os.getenv("BGE_LORA_PATH") -BATCH_SIZE = 250 \ No newline at end of file +BATCH_SIZE = 100 \ No newline at end of file diff --git a/src/serve_embed.py b/src/serve_embed.py index 6caa30c..b537f5a 100644 --- a/src/serve_embed.py +++ b/src/serve_embed.py @@ -55,6 +55,7 @@ def embed_gemma(request: EmbedRequest): raise HTTPException(status_code=400, detail="Invalid model") except Exception as e: + torch.cuda.empty_cache() raise HTTPException(status_code=500, detail=str(e)) torch.cuda.empty_cache()