30 lines
946 B
Python
30 lines
946 B
Python
from sentence_transformers import SentenceTransformer
|
|
import requests
|
|
from config.base import BATCH_SIZE
|
|
import torch
|
|
|
|
|
|
class TextEmbedderBgeTrain:
    """Batched text embedder built on a LoRA-adapted SentenceTransformer.

    Loads a base model from a local path, applies a LoRA adapter on top,
    and exposes batched query/document embedding. Defaults preserve the
    original behavior (``cuda:0`` device, project-wide ``BATCH_SIZE``).
    """

    def __init__(self, model_path, lora_path, device="cuda:0"):
        """Load the base model and apply the LoRA adapter.

        Args:
            model_path: Local directory containing the base
                SentenceTransformer model.
            lora_path: Path to the LoRA adapter weights to load on top
                of the base model.
            device: Torch device string used for inference. Defaults to
                ``"cuda:0"``, matching the previously hard-coded value.
        """
        # local_files_only avoids any network fetch; trust_remote_code is
        # needed by model repos that ship custom modeling code.
        self.model = SentenceTransformer(
            model_path, trust_remote_code=True, local_files_only=True
        ).to(device=device)
        self.model.load_adapter(lora_path)

    def embed_texts(
        self,
        texts: list[str],
        query: bool = False,
        batch_size: int = BATCH_SIZE,
    ) -> list[list[float]]:
        """Embed *texts* in batches and return one vector per input text.

        Args:
            texts: Strings to embed; an empty list yields an empty result.
            query: If True, encode with the model's query-side encoder
                (``encode_query``); otherwise with the document-side
                encoder (``encode_document``). BGE-style retrieval models
                apply different prompts to queries and documents.
            batch_size: Number of texts per forward pass. Defaults to the
                project-wide ``BATCH_SIZE``, matching the old behavior.

        Returns:
            Embedding vectors in the same order as *texts*.
        """
        # The query/document choice is loop-invariant — resolve it once.
        encode = self.model.encode_query if query else self.model.encode_document

        all_embeddings: list[list[float]] = []
        for start in range(0, len(texts), batch_size):
            all_embeddings.extend(encode(texts[start:start + batch_size]))

        # Release cached CUDA allocator blocks so GPU memory pressure stays
        # low between calls (no-op when CUDA was never initialized).
        torch.cuda.empty_cache()
        return all_embeddings