46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
from hazm import Normalizer
|
|
import requests
|
|
import numpy as np
|
|
from dotenv import load_dotenv
|
|
import os
|
|
import time
|
|
|
|
# Populate os.environ from a local .env file at import time, so that the
# os.getenv('EMBEDDING_PASS') lookup in TextEmbedder.__init__ can pick up a
# key stored there (presumably the intended source of the credential).
load_dotenv()
|
|
|
|
|
|
class TextEmbedder:
    """Client for an OpenAI-style ``/v1/embeddings`` HTTP endpoint.

    Texts are optionally normalized with hazm's ``Normalizer`` before being
    sent. The returned vectors come back either as float32 NumPy arrays or
    as plain Python lists of floats.
    """

    # URL used when no ``endpoint`` is passed to the constructor. This is the
    # address the class always used before it became configurable.
    DEFAULT_ENDPOINT = "http://78.38.161.78:3094/v1/embeddings"

    def __init__(self, model_name="BAAI/bge-m3", endpoint=None, api_key=None, timeout=120.0):
        """Configure the embedder.

        Args:
            model_name: Model identifier sent as ``"model"`` in the request body.
            endpoint: Full URL of the embeddings endpoint. Defaults to
                ``DEFAULT_ENDPOINT``.
            api_key: Bearer token for the ``Authorization`` header. Falls back
                to the ``EMBEDDING_PASS`` environment variable. If neither
                is set, no ``Authorization`` header is sent.
            timeout: Seconds passed as ``timeout=`` to ``requests.post``.
                ``None`` waits indefinitely.
        """
        self.model_name = model_name
        self.endpoint = self.DEFAULT_ENDPOINT if endpoint is None else endpoint
        self.timeout = timeout

        if api_key is None:
            api_key = os.getenv("EMBEDDING_PASS")
        self.headers = {"Content-Type": "application/json"}
        if api_key:
            # Without this guard a missing key would be sent as the literal
            # string "Bearer None".
            self.headers["Authorization"] = f"Bearer {api_key}"

        self.normalizer = Normalizer()

    def preprocess_embedder(self, text: str) -> str:
        """Return *text* prepared for embedding.

        Newlines are replaced with "." (presumably so line breaks read as
        sentence boundaries). The result is then passed through hazm's
        ``Normalizer.normalize``.
        """
        text = text.replace("\n", ".")
        return self.normalizer.normalize(text)

    def embed_texts(self, texts: list[str], do_preprocess=True, convert_to_numpy=True) -> "list[np.ndarray] | list[list[float]]":
        """Embed *texts* using the remote model.

        Args:
            texts: Strings to embed. An empty sequence returns ``[]``
                without making a request.
            do_preprocess: If true, run each text through
                ``preprocess_embedder`` before sending it.
            convert_to_numpy: If true, return each vector as a float32
                ``np.ndarray``. Otherwise return it as the list of floats
                from the JSON response.

        Returns:
            One embedding per input text. The embeddings are ordered like
            ``texts``, assuming the server reports each item's input
            position in ``index`` (as the OpenAI embeddings API does).

        Raises:
            requests.HTTPError: If the server responds with an error status.
            requests.RequestException: On connection failures or timeouts.
        """
        if not texts:
            return []

        if do_preprocess:
            texts = [self.preprocess_embedder(text) for text in texts]

        payload = {
            "model": self.model_name,
            "input": texts,
        }
        response = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=self.timeout)
        # Fail with a descriptive HTTP error instead of a KeyError on "data".
        response.raise_for_status()

        # Sort on the per-item "index" so results line up with the inputs.
        # The sort is stable, so a missing "index" keeps the server's order.
        items = sorted(response.json()["data"], key=lambda item: item.get("index", 0))
        vectors = [item["embedding"] for item in items]

        if convert_to_numpy:
            return [np.array(vector, dtype=np.float32) for vector in vectors]
        return vectors
|