text_clustering/post_cluster.py
2025-10-21 11:14:59 +03:30

203 lines
7.0 KiB
Python

import asyncio
import aiohttp
import time
import re
import pandas as pd
import json
from tqdm import tqdm
class PostClusterLLM:
    """Assign topic titles to a fixed set of clusters by querying an LLM.

    The system prompt (built in ``__init__``) instructs the model to answer
    with a small JSON object: ``{"cluster": "<id>"}`` when a cluster fits,
    or ``{"outlier": "yes"}`` when none does.
    """

    def __init__(self):
        # System prompt sent with every request in run_llm.  It is an
        # f-string so the doubled braces render as literal JSON braces;
        # the examples use Persian titles/cluster names (the target data
        # language).  Prompt text is runtime behavior — kept verbatim.
        self.instruction = f"""
You will be given a title and a list of all cluster names.
Your task is to find the best fit cluster name for the title.
Go through the list of all cluster names and find the best fit cluster name for the title.
If you found a good fit, return the cluster name.
If you didn't find a good fit, return "outlier" is "yes".
#IMPORTANT:
- if you found a good fit use its id : {{"cluster" : "id_i"}}
- if the title is not related to any of the cluster names, return "outlier" is "yes" : {{"outlier" : "yes"}}
Example-1:
- Input:
- title: "کتاب و درس"
- all_cluster_names: {{
"1" : "کتابخوانی",
"2" : "فوتبال جام جهانی",
"3" : "ساختمان سازی شهری" }}
- Output:
- {{"cluster" : "1"}}
Example-2:
- Input:
- title: "لپتاب و کامپیوتر"
- all_cluster_names: {{
"1" : "کتابخوانی",
"2" : "فوتبال جام جهانی",
"3" : "ساختمان سازی شهری" }}
- Output:
- {{"outlier" : "yes"}}
Example-3:
- Input:
- title: "ساختمان"
- all_cluster_names: {{
"1" : "کتابخوانی",
"2" : "فوتبال جام جهانی",
"3" : "ساختمان سازی شهری" }}
- Output:
- {{"cluster" : "3"}}
write a small reason and give the final answer.
"""
async def run_llm(self, session, title, cluster_names):
"""
Run the LLM as reranker.
Args:
session: The session to use for the request.
question: The question to rerank the documents.
chunk: The chunk to rerank.
Returns:
The score of the chunk.
"""
headers = {"Content-Type": "application/json",}
input_message = f"""{{"all_cluster_names": "{cluster_names}", "title": "{title}"}}"""
messages = [{"role": "system", "content": self.instruction}, {"role": "user", "content": input_message}]
payload = {
"model": "google/gemma-3-27b-it",
"messages": messages,
"max_tokens": 500
}
try:
async with session.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload) as resp:
resp.raise_for_status()
response = await resp.json()
out = response['choices'][0]['message']['content']
print("--------------------------------")
print(f"title: {title}")
print(out)
pattern = r'(\{"cluster".*?\})'
matches = re.findall(pattern, out)
for m in matches:
out_json = json.loads(m)
print(f"out_json: {out_json}")
return out_json
pattern = r'(\{"outlier".*?\})'
matches = re.findall(pattern, out)
for m in matches:
out_json = json.loads(m)
print(f"out_json: {out_json}")
return out_json
except Exception as e:
print(f"Error in llm as reranker: {e}")
return 0
async def run_llm_async(self, titles, cluster_names):
"""
Send all chunk requests concurrently.
Args:
titles: The titles to rerank.
possible_cluster_names: The possible cluster names to rerank.
cluster_names: The cluster names to rerank.
Returns:
The scores of the chunks.
"""
async with aiohttp.ClientSession() as session:
tasks = [self.run_llm(session, title, cluster_names) for title in titles]
scores_embed = await asyncio.gather(*tasks)
return scores_embed
def sanitize_for_excel(self, df):
def _sanitize_for_excel(text):
"""Remove zero-width and bidi control characters that can confuse Excel rendering."""
if text is None:
return ""
s = str(text)
# Characters to remove: ZWNJ, ZWJ, RLM, LRM, RLE, LRE, PDF, BOM, Tatweel
remove_chars = [
"\u200c", # ZWNJ
"\u200d", # ZWJ
"\u200e", # LRM
"\u200f", # RLM
"\u202a", # LRE
"\u202b", # RLE
"\u202c", # PDF
"\u202d", # LRO
"\u202e", # RLO
"\ufeff", # BOM
"\u0640", # Tatweel
]
for ch in remove_chars:
s = s.replace(ch, "")
# Normalize whitespace
s = re.sub(r"\s+", " ", s).strip()
return s
df_copy = df.copy()
for m in df.columns:
for i in range(len(df_copy[m])):
df_copy.loc[i, m] = _sanitize_for_excel(df_copy.loc[i, m])
return df_copy
def start_process(self, input_path, output_path):
df = pd.read_excel(input_path)
df_copy = df.copy()
with open("titles_o3.txt", "r") as f:
titles = f.readlines()
titles = [title.strip() for title in titles]
cluster_names_dict = {}
count = 1
for item in titles:
cluster_names_dict[str(count)] = item
count += 1
cluster_names = "{\n"
for key, value in cluster_names_dict.items():
cluster_names += f"{key} : {value},\n"
cluster_names += "}"
batch_size = 100
for i in tqdm(range(0, len(df["topic"]), batch_size)):
start_time = time.time()
result_list = asyncio.run(self.run_llm_async(df["topic"][i:i+batch_size], cluster_names))
end_time = time.time()
print(f"Time taken for llm as reranker: {end_time - start_time} seconds")
time.sleep(5)
for j, result in enumerate(result_list):
try:
if result.get("outlier") == "yes":
df_copy.at[i+j, "cluster_llm"] = "متفرقه"
elif result.get("cluster") is not None:
df_copy.at[i+j, "cluster_llm"] = cluster_names_dict[result["cluster"]]
else:
df_copy.at[i+j, "cluster_llm"] = df_copy.at[i+j, "category"]
except Exception as e:
print(f"Error in result_list: {e}")
df_copy.at[i+j, "cluster_llm"] = df_copy.at[i+j, "category"]
df_copy = self.sanitize_for_excel(df_copy)
df_copy.to_excel(output_path)
if __name__ == "__main__":
    # Script entry point: post-process the clustered tweet-topic workbook.
    llm = PostClusterLLM()
    llm.start_process(
        "/home/firouzi/trend_grouping_new/tweet_topic_recreation.xlsx",
        "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3.xlsx",
    )