# LLM-as-reranker: asks an LLM for a 0/1 relevance judgment per (question, text) pair.
from typing import List, Dict, Any
|
|
import json
|
|
import asyncio
|
|
import aiohttp
|
|
import time
|
|
import re
|
|
import os
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
model = os.getenv('LLM_AS_RERANKER_MODEL')
|
|
model_url = os.getenv('LLM_AS_RERANKER_URL')
|
|
model_pass = os.getenv('LLM_AS_RERANKER_PASS')
|
|
|
|
class LLMModel:
    """Relevance judge that asks an LLM whether a text answers a question.

    Used to filter false-negative samples: each (question, text) pair is sent
    to an OpenAI-compatible chat endpoint and the model's "1"/"0" verdict is
    collected. Configuration (model name, URL, token) comes from the
    module-level environment variables.
    """

    def __init__(self):
        # System prompt: instructs the model to emit {"result": "1"} or
        # {"result": "0"}. Kept verbatim — changing the wording changes
        # model behavior.
        self.instruction = """
You are a helpful assistant that help me to find that the text is relevant to the question or not.
You are given a question and a text.
You must evaluate the text based on the question and return "1" if the text is relevant to the question and "0" if the text is not relevant to the question.

be carefull, I have chosen the text randomly from my dataset so the text must answer the question independently.
You must return the result in the following format:
{{"result": "1" or "0"}}
"""

    async def run_llm(self, session, question, text):
        """Ask the LLM whether *text* is relevant to *question*.

        Args:
            session: The aiohttp client session to use for the request.
            question: The question to evaluate the text against.
            text: The candidate text to judge.

        Returns:
            "1" if the model judges the text relevant, "0" otherwise —
            including on any request, parsing, or format error (best-effort:
            errors are logged, never raised to the caller).
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {model_pass}",
        }

        # json.dumps escapes quotes/newlines, so the user message stays valid
        # JSON even when the question or text contains special characters
        # (the original raw f-string broke on embedded quotes).
        input_message = json.dumps({"question": question, "text": text})

        messages = [
            {"role": "system", "content": self.instruction},
            {"role": "user", "content": input_message},
        ]

        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": 100,
        }

        out = None  # keep raw model output in scope for error logging
        try:
            async with session.post(model_url + "/chat/completions", headers=headers, json=payload) as resp:
                resp.raise_for_status()
                response = await resp.json()

            out = response['choices'][0]['message']['content']

            # Tolerant extraction: accepts quoted or unquoted result values.
            match = re.search(r'"result":\s*"?([\d\.]+)"?', out)

            if not match:
                # Model ignored the required output format — treat as
                # "not relevant" (fixes the original implicit None return).
                print(f"Error in llm model, unparseable output: {out}")
                return "0"

            result = match.group(1)
            if result not in ["0", "1"]:
                # Fixed: the original printed an undefined variable `e` here.
                print(f"Error in llm model, unexpected result: {out}")
                return "0"

            return result

        except Exception as e:
            # `out` may not have been assigned if the request itself failed.
            if out is not None:
                print(f"Error in llm model {out}: {e}")
            else:
                print(f"Error in llm model: {e}")
            return "0"

    async def run_llm_async(self, question_list, text_list):
        """Send all relevance requests concurrently.

        Args:
            question_list: The list of questions.
            text_list: The list of texts, paired positionally with the
                questions (extra items in the longer list are ignored).

        Returns:
            A list of "0"/"1" strings, one per (question, text) pair,
            in input order.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [
                self.run_llm(session, question, text)
                for question, text in zip(question_list, text_list)
            ]
            results = await asyncio.gather(*tasks)
            return results

    def remove_false_negative_llm(self, query_list: List[str], text_list: List[str]) -> List[str]:
        """Judge each (query, text) pair with the LLM to flag false negatives.

        Args:
            query_list: The list of queries.
            text_list: The list of texts, paired positionally with the queries.

        Returns:
            A list of "0"/"1" relevance judgments, one per pair ("1" means
            the text is relevant to its query). Empty list for empty input.
        """
        if not text_list:
            return []

        start_time = time.time()
        results = asyncio.run(self.run_llm_async(query_list, text_list))
        end_time = time.time()
        print(f"Time taken for llm model: {end_time - start_time} seconds")

        return results
|
|
|
|
|