189 lines
7.3 KiB
Python
189 lines
7.3 KiB
Python
import asyncio
|
|
import aiohttp
|
|
import time
|
|
import re
|
|
import pandas as pd
|
|
import json
|
|
from tqdm import tqdm
|
|
|
|
class PostSubClusterLLM:
    """Assign each topic title to a sub-cluster by querying an LLM.

    For every row of an input Excel sheet (columns ``topic`` and
    ``cluster_llm``), the class sends the title plus the candidate
    sub-cluster names of its top-level cluster to a chat-completions
    endpoint, parses the JSON answer out of the free-text reply, and
    writes the resolved sub-cluster name into a ``sub_cluster`` column.
    """

    def __init__(self):
        # System prompt for the LLM. The f-string is required so that the
        # doubled braces render as literal single braces in the example
        # JSON answers.
        self.instruction = f"""
You will be given a title and a list of all cluster names.
Your task is to find the best fit cluster name for the title.
Go through the list of all cluster names and find the best fit cluster name for the title.
If you found a good fit, return the cluster name.
If you didn't find a good fit, return "outlier" is "yes".

#IMPORTANT:
- if you found a good fit use its id : {{"cluster" : "id_i"}}
- if the title is not related to any of the cluster names, return "outlier" is "yes" : {{"outlier" : "yes"}}

write a small reason and give the final answer.
"""

    @staticmethod
    def _lookup_sub_cluster_names(cluster_name, cluster_sub_cluster_list):
        """Return the sub-cluster name list for *cluster_name*, or None if unknown.

        Args:
            cluster_name: Top-level cluster name to look up.
            cluster_sub_cluster_list: List of dicts, each with keys
                "cluster_name" and "sub_cluster_names".
        """
        for entry in cluster_sub_cluster_list:
            if entry["cluster_name"] == cluster_name:
                return entry["sub_cluster_names"]
        return None

    @staticmethod
    def _parse_cluster_index(value):
        """Extract an integer index from an LLM cluster id.

        The prompt asks the model to answer with ids like ``"id_i"``, so the
        reply may be ``3``, ``"3"`` or ``"id_3"``.  A bare ``int(...)`` would
        raise on the ``"id_3"`` form; instead pull out the first digit run.

        Returns:
            int index, or None when *value* contains no digits.
        """
        match = re.search(r"\d+", str(value))
        return int(match.group()) if match else None

    async def run_llm(self, session, topic, cluster_name, cluster_sub_cluster_list):
        """Ask the LLM to pick the best-fitting sub-cluster for one title.

        Args:
            session: Shared aiohttp ClientSession used for the POST request.
            topic: The title to classify.
            cluster_name: Top-level cluster the title already belongs to.
            cluster_sub_cluster_list: List of dicts with keys
                "cluster_name" and "sub_cluster_names".

        Returns:
            dict: ``{"cluster": id}`` or ``{"outlier": "yes"}`` parsed from
                the LLM reply.
            None: the cluster is the catch-all "متفرقه" bucket, the cluster
                name is unknown, or the reply matched neither pattern.
            0: error sentinel kept for backward compatibility — the request
                or parsing raised (mapped to "موارد دیگر" downstream).
        """
        # The catch-all cluster is never sub-clustered.
        if cluster_name == "متفرقه":
            return None

        headers = {"Content-Type": "application/json"}

        sub_cluster_names = self._lookup_sub_cluster_names(cluster_name, cluster_sub_cluster_list)
        if sub_cluster_names is None:
            return None

        # Render candidates as "index : name," lines so the model can answer
        # with a numeric id instead of echoing a (possibly mangled) name.
        sub_cluster_names_str = "{\n" + "".join(
            f"{count} : {value},\n" for count, value in enumerate(sub_cluster_names)
        ) + "}"

        input_message = f"""{{"all_cluster_names": "{sub_cluster_names_str}", "title": "{topic}"}}"""
        messages = [{"role": "system", "content": self.instruction}, {"role": "user", "content": input_message}]

        payload = {
            "model": "google/gemma-3-27b-it",
            "messages": messages,
            "max_tokens": 500
        }
        try:
            async with session.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload) as resp:
                resp.raise_for_status()
                response = await resp.json()

            out = response['choices'][0]['message']['content']
            print("--------------------------------")
            print(f"title: {topic}")
            print(f"cluster_name: {cluster_name}")
            print(out)

            # First look for a cluster answer embedded in the free-text reply.
            for m in re.findall(r'(\{"cluster".*?\})', out):
                out_json = json.loads(m)
                print(f"out_json: {out_json}")
                if out_json.get("cluster") is not None:
                    idx = self._parse_cluster_index(out_json["cluster"])
                    # Log the resolved name when the id is usable; an
                    # unusable id is still returned and resolved (or
                    # rejected) by the caller instead of crashing here.
                    if idx is not None and idx < len(sub_cluster_names):
                        print(sub_cluster_names[idx])
                    return out_json

            # Otherwise look for an explicit outlier answer.
            for m in re.findall(r'(\{"outlier".*?\})', out):
                out_json = json.loads(m)
                print(f"out_json: {out_json}")
                print("outlier")
                return out_json
        except Exception as e:
            print(f"Error in llm as reranker: {e}")
            return 0

    async def run_llm_async(self, topics, cluster_names, cluster_sub_cluster_dict):
        """Send all per-title LLM requests concurrently.

        Args:
            topics: Iterable of titles to classify.
            cluster_names: Iterable of top-level cluster names, aligned
                with *topics*.
            cluster_sub_cluster_dict: List of cluster/sub-cluster dicts
                passed through to :meth:`run_llm`.

        Returns:
            List of :meth:`run_llm` results, in input order.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [self.run_llm(session, topic, cluster_name, cluster_sub_cluster_dict) for topic, cluster_name in zip(topics, cluster_names)]
            return await asyncio.gather(*tasks)

    def sanitize_for_excel(self, df):
        """Return a copy of *df* with invisible/bidi control chars stripped from every cell.

        Every value is coerced to ``str`` (None becomes ""), the characters
        below are removed in one pass, and runs of whitespace are collapsed.
        """
        # Characters to remove: ZWNJ, ZWJ, LRM, RLM, LRE, RLE, PDF, LRO,
        # RLO, BOM, Tatweel — all can confuse Excel's RTL rendering.
        remove_table = str.maketrans({
            "\u200c": None,  # ZWNJ
            "\u200d": None,  # ZWJ
            "\u200e": None,  # LRM
            "\u200f": None,  # RLM
            "\u202a": None,  # LRE
            "\u202b": None,  # RLE
            "\u202c": None,  # PDF
            "\u202d": None,  # LRO
            "\u202e": None,  # RLO
            "\ufeff": None,  # BOM
            "\u0640": None,  # Tatweel
        })

        def _sanitize_for_excel(text):
            """Clean one cell value; single-pass translate beats chained replaces."""
            if text is None:
                return ""
            s = str(text).translate(remove_table)
            # Normalize whitespace
            return re.sub(r"\s+", " ", s).strip()

        df_copy = df.copy()
        # Assign whole columns instead of mixing positional range(len(...))
        # with label-based .loc — the original broke on non-default indexes.
        for col in df_copy.columns:
            df_copy[col] = [_sanitize_for_excel(v) for v in df_copy[col]]
        return df_copy

    def start_process(self, input_path, titles_path, output_path):
        """Run sub-clustering over the whole sheet and write the result.

        Args:
            input_path: Excel file with ``topic`` and ``cluster_llm`` columns.
            titles_path: JSON file with the cluster → sub-cluster mapping.
            output_path: Excel file to write, with a new ``sub_cluster`` column.
        """
        df = pd.read_excel(input_path)
        df_copy = df.copy()

        # Explicit UTF-8: the mapping contains Persian text and must not
        # depend on the platform's default encoding.
        with open(titles_path, "r", encoding="utf-8") as f:
            cluster_sub_cluster_list = json.load(f)

        batch_size = 100
        for i in tqdm(range(0, len(df["topic"]), batch_size)):
            start_time = time.time()
            result_list = asyncio.run(self.run_llm_async(
                df["topic"][i:i + batch_size],
                df["cluster_llm"][i:i + batch_size],
                cluster_sub_cluster_list,
            ))
            end_time = time.time()
            print(f"Time taken for llm as reranker: {end_time - start_time} seconds")
            time.sleep(5)  # back off between batches so the LLM server isn't hammered

            for j, result in enumerate(result_list):
                try:
                    if result is None:
                        # Catch-all cluster or unknown cluster name.
                        df_copy.at[i + j, "sub_cluster"] = "متفرقه"
                    elif not isinstance(result, dict):
                        # run_llm's error sentinel (0): request/parsing failed.
                        df_copy.at[i + j, "sub_cluster"] = "موارد دیگر"
                    elif result.get("outlier") == "yes":
                        df_copy.at[i + j, "sub_cluster"] = "موارد دیگر"
                    elif result.get("cluster") is not None:
                        sub_cluster_names = self._lookup_sub_cluster_names(
                            df["cluster_llm"][i + j], cluster_sub_cluster_list)
                        idx = self._parse_cluster_index(result["cluster"])
                        if (sub_cluster_names is not None and idx is not None
                                and idx < len(sub_cluster_names)):
                            df_copy.at[i + j, "sub_cluster"] = sub_cluster_names[idx]
                        else:
                            # Missing cluster or unusable id — same bucket
                            # the original reached via its exception path.
                            df_copy.at[i + j, "sub_cluster"] = "موارد دیگر"
                    else:
                        df_copy.at[i + j, "sub_cluster"] = "موارد دیگر"
                except Exception as e:
                    print(f"Error in result_list: {e}")
                    df_copy.at[i + j, "sub_cluster"] = "موارد دیگر"

                print(df_copy.at[i + j, "sub_cluster"])

        df_copy = self.sanitize_for_excel(df_copy)
        df_copy.to_excel(output_path)
|
|
|
|
if __name__ == "__main__":
    # Hard-coded paths for the recreation-topics run.
    input_path = "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3.xlsx"
    titles_path = "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3_subcategory.json"
    output_path = "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3_subcategory.xlsx"

    classifier = PostSubClusterLLM()
    classifier.start_process(input_path, titles_path, output_path)