# serve_embed/models/embedder_gemma.py
# 2026-03-14 11:23:42 +03:30
#
# 27 lines
# 895 B
# Python

from sentence_transformers import SentenceTransformer
import requests
import torch
from config.base import BATCH_SIZE
class TextEmbedderGemma:
    """Batched text-embedding wrapper around a locally stored SentenceTransformer model."""

    def __init__(self, model_path: str, device: str = "cuda:0"):
        """Load the model from *model_path* onto *device*.

        Args:
            model_path: Filesystem path of the SentenceTransformer model.
                ``local_files_only=True`` — no hub download is attempted.
            device: Torch device string; defaults to the first CUDA GPU
                (previously hard-coded).
        """
        self.model = SentenceTransformer(
            model_path, trust_remote_code=True, local_files_only=True
        ).to(device=device)

    def embed_texts(
        self,
        texts: list[str],
        query: bool = False,
        batch_size: "int | None" = None,
    ) -> list[list[float]]:
        """Embed *texts* in batches and return one embedding per input, in order.

        Args:
            texts: Strings to embed; an empty list yields an empty result.
            query: If True, use the model's query encoder; otherwise the
                document encoder.
            batch_size: Texts per encode call; defaults to
                ``config.base.BATCH_SIZE``.

        Returns:
            Embeddings in input order. NOTE(review): SentenceTransformer's
            encode methods typically return numpy arrays, which are extended
            as-is — elements may be ndarrays rather than plain lists despite
            the annotation; confirm what callers expect before converting.
        """
        if batch_size is None:
            batch_size = BATCH_SIZE
        # Hoist the loop-invariant encoder choice out of the batch loop.
        encode = self.model.encode_query if query else self.model.encode_document
        all_embeddings: list = []
        for start in range(0, len(texts), batch_size):
            all_embeddings.extend(encode(texts[start:start + batch_size]))
        # Release cached GPU memory once per call, not per batch:
        # empty_cache() forces a device synchronization and is costly
        # inside the loop; a single call preserves the memory-release intent.
        torch.cuda.empty_cache()
        return all_embeddings