text_clustering/post_sub_cluster.py
2025-10-22 14:31:39 +03:30

214 lines
8.1 KiB
Python

import asyncio
import aiohttp
import time
import re
import pandas as pd
import json
from tqdm import tqdm
class PostSubClusterLLM:
    """Assign every topic of a spreadsheet to a sub-cluster of its parent cluster.

    For each row, the topic text and the list of sub-cluster names that belong
    to the row's parent cluster are sent to an OpenAI-compatible
    ``/v1/chat/completions`` endpoint. The model answers with either
    ``{"cluster": "<id>"}`` (index into the sub-cluster list) or
    ``{"outlier": "yes"}``; the answer is written to a ``sub_cluster`` column.
    """

    DEFAULT_API_URL = "http://192.168.130.206:4001/v1/chat/completions"
    DEFAULT_MODEL = "google/gemma-3-27b-it"
    DEFAULT_MAX_TOKENS = 500

    # Parent cluster that is never sent to the LLM; its rows keep this label.
    MISC_CLUSTER = "متفرقه"
    # Label used for outliers and for every row whose answer could not be used.
    OTHER_SUB_CLUSTER = "موارد دیگر"

    # Non-greedy match of the first flat JSON object that starts with the key.
    # ``\s*`` and DOTALL tolerate `{ "cluster": ... }` and multi-line answers.
    _CLUSTER_PATTERN = re.compile(r'\{\s*"cluster".*?\}', re.DOTALL)
    _OUTLIER_PATTERN = re.compile(r'\{\s*"outlier".*?\}', re.DOTALL)

    # ZWNJ, ZWJ, LRM, RLM, LRE, RLE, PDF, LRO, RLO, BOM, Tatweel.
    _EXCEL_STRIP_TABLE = str.maketrans(
        "",
        "",
        "\u200c\u200d\u200e\u200f\u202a\u202b\u202c\u202d\u202e\ufeff\u0640",
    )

    def __init__(self, api_url=DEFAULT_API_URL, model=DEFAULT_MODEL, max_tokens=DEFAULT_MAX_TOKENS):
        """Store the endpoint settings and build the system prompt.

        Args:
            api_url: URL of the chat-completions endpoint.
            model: Model name placed in the request payload.
            max_tokens: ``max_tokens`` value placed in the request payload.
        """
        self.api_url = api_url
        self.model = model
        self.max_tokens = max_tokens
        # NOTE(review): the examples below number clusters from "1", while
        # run_llm sends ids produced by enumerate() (starting at 0) and indexes
        # the list with the returned id. The model sees the real ids in every
        # request, but confirm the 1-based examples do not bias its answers.
        # The text is runtime data sent to the model, so it is kept verbatim.
        self.instruction = f"""
You will be given a title and a list of all cluster names.
Your task is to find the best fit cluster name for the title.
Go through the list of all cluster names and find the best fit cluster name for the title.
If you found a good fit, return the cluster name.
If you didn't find a good fit, return "outlier" is "yes".
#IMPORTANT:
- if you found a good fit use its id : {{"cluster" : "id_i"}}
- if the title is not related to any of the cluster names, return "outlier" is "yes" : {{"outlier" : "yes"}}
Example-1:
- Input:
- title: "کتاب و درس"
- all_cluster_names: {{
"1" : "کتابخوانی",
"2" : "فوتبال جام جهانی",
"3" : "ساختمان سازی شهری" }}
- Output:
- {{"cluster" : "1"}}
Example-2:
- Input:
- title: "لپتاب و کامپیوتر"
- all_cluster_names: {{
"1" : "کتابخوانی",
"2" : "فوتبال جام جهانی",
"3" : "ساختمان سازی شهری" }}
- Output:
- {{"outlier" : "yes"}}
Example-3:
- Input:
- title: "ساختمان"
- all_cluster_names: {{
"1" : "کتابخوانی",
"2" : "فوتبال جام جهانی",
"3" : "ساختمان سازی شهری" }}
- Output:
- {{"cluster" : "3"}}
write a small reason and give the final answer.
"""

    @staticmethod
    def _find_sub_cluster_names(cluster_name, cluster_sub_cluster_list):
        """Return ``sub_cluster_names`` of the first entry named *cluster_name*, or None."""
        for entry in cluster_sub_cluster_list:
            if entry["cluster_name"] == cluster_name:
                return entry["sub_cluster_names"]
        return None

    @staticmethod
    def _load_json_object(raw):
        """Parse *raw* as a JSON object; return None when it is not one."""
        try:
            parsed = json.loads(raw)
        except ValueError:  # json.JSONDecodeError is a ValueError
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _cluster_index(answer, size):
        """Return ``answer["cluster"]`` as an index in ``range(size)``, else None.

        Negative ids are rejected on purpose: Python would otherwise accept
        them and silently pick a sub-cluster counted from the end of the list.
        """
        try:
            index = int(answer.get("cluster"))
        except (TypeError, ValueError, OverflowError):
            return None
        return index if 0 <= index < size else None

    @classmethod
    def _parse_llm_output(cls, out, sub_cluster_names):
        """Extract the model's decision from its free-text answer.

        ``{"cluster": ...}`` candidates are tried first, in order of
        appearance; a candidate that is not valid JSON or whose id is not a
        valid index is skipped instead of aborting the whole parse. If no
        cluster candidate is usable, the first parseable ``{"outlier": ...}``
        object is returned.

        Returns:
            The parsed dict, or ``0`` when nothing usable was found.
        """
        if not isinstance(out, str):
            return 0
        for raw in cls._CLUSTER_PATTERN.findall(out):
            answer = cls._load_json_object(raw)
            if answer is None:
                continue
            index = cls._cluster_index(answer, len(sub_cluster_names))
            if index is None:
                continue
            print(f"out_json: {answer}")
            print(sub_cluster_names[index])
            return answer
        for raw in cls._OUTLIER_PATTERN.findall(out):
            answer = cls._load_json_object(raw)
            if answer is None:
                continue
            print(f"out_json: {answer}")
            print("outlier")
            return answer
        return 0

    async def run_llm(self, session, topic, cluster_name, cluster_sub_cluster_list):
        """Ask the LLM which sub-cluster of *cluster_name* fits *topic*.

        Args:
            session: Object whose ``post(url, headers=..., json=...)`` returns
                an async context manager (an ``aiohttp.ClientSession`` in
                ``run_llm_async``).
            topic: Title text to classify.
            cluster_name: Parent cluster of the topic.
            cluster_sub_cluster_list: List of dicts with the keys
                ``"cluster_name"`` and ``"sub_cluster_names"``.

        Returns:
            ``None`` only when *cluster_name* is the miscellaneous cluster
            (no request is made); the parsed answer dict
            (``{"cluster": id}`` or ``{"outlier": ...}``) on success; ``0``
            when the cluster is unknown, the request fails, or the answer
            contains no usable JSON.
        """
        if cluster_name == self.MISC_CLUSTER:
            return None

        sub_cluster_names = self._find_sub_cluster_names(cluster_name, cluster_sub_cluster_list)
        if sub_cluster_names is None:
            # Previously this path hit an unbound local and aborted the whole
            # batch; report it through the same error sentinel as other failures.
            print(f"Error in llm as reranker: unknown cluster_name {cluster_name!r}")
            return 0

        headers = {"Content-Type": "application/json",}
        numbered_names = "".join(
            f"{count} : {value},\n" for count, value in enumerate(sub_cluster_names)
        )
        sub_cluster_names_str = "{\n" + numbered_names + "}"
        input_message = f"""{{"all_cluster_names": "{sub_cluster_names_str}", "title": "{topic}"}}"""
        messages = [
            {"role": "system", "content": self.instruction},
            {"role": "user", "content": input_message},
        ]
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
        }

        try:
            async with session.post(self.api_url, headers=headers, json=payload) as resp:
                resp.raise_for_status()
                response = await resp.json()
            out = response['choices'][0]['message']['content']
        except Exception as e:
            # Per-row best effort: one failed request must not cancel the
            # other requests gathered in the same batch.
            print(f"Error in llm as reranker: {e}")
            return 0

        print("--------------------------------")
        print(f"title: {topic}")
        print(f"cluster_name: {cluster_name}")
        print(out)
        return self._parse_llm_output(out, sub_cluster_names)

    async def run_llm_async(self, topics, cluster_names, cluster_sub_cluster_dict):
        """Classify all topics concurrently over one shared HTTP session.

        Args:
            topics: Iterable of title texts.
            cluster_names: Iterable of parent cluster names, paired with
                *topics* by position (``zip``).
            cluster_sub_cluster_dict: Despite the name, the list of
                ``{"cluster_name", "sub_cluster_names"}`` dicts that is passed
                through to ``run_llm``.

        Returns:
            List of ``run_llm`` results in the order of *topics*.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [
                self.run_llm(session, topic, cluster_name, cluster_sub_cluster_dict)
                for topic, cluster_name in zip(topics, cluster_names)
            ]
            scores_embed = await asyncio.gather(*tasks)
        return scores_embed

    def sanitize_for_excel(self, df):
        """Return a copy of *df* with every cell converted to cleaned text.

        Zero-width and bidi control characters, BOM and Tatweel are removed,
        runs of whitespace collapse to one space, and the text is stripped.
        Missing cells (``None``/NaN/NaT) become ``""`` instead of the literal
        text ``"nan"``. Columns are processed as a whole, so the result does
        not depend on the index labels of *df*.
        """
        strip_table = self._EXCEL_STRIP_TABLE

        def _sanitize_for_excel(text):
            """Remove zero-width and bidi control characters that can confuse Excel rendering."""
            if text is None or (pd.api.types.is_scalar(text) and pd.isna(text)):
                return ""
            s = str(text).translate(strip_table)
            # Normalize whitespace
            return re.sub(r"\s+", " ", s).strip()

        df_copy = df.copy()
        for column in df_copy.columns:
            df_copy[column] = df_copy[column].map(_sanitize_for_excel)
        return df_copy

    @classmethod
    def _resolve_sub_cluster(cls, result, cluster_name, cluster_sub_cluster_list):
        """Translate one ``run_llm`` result into the sub-cluster label for a row."""
        if result is None:
            # run_llm returns None only for rows of the miscellaneous cluster.
            return cls.MISC_CLUSTER
        if not isinstance(result, dict):
            # Error sentinel (0): request failed or answer was unusable.
            return cls.OTHER_SUB_CLUSTER
        if result.get("outlier") == "yes":
            return cls.OTHER_SUB_CLUSTER
        if result.get("cluster") is not None:
            sub_cluster_names = cls._find_sub_cluster_names(cluster_name, cluster_sub_cluster_list)
            index = None
            if sub_cluster_names is not None:
                index = cls._cluster_index(result, len(sub_cluster_names))
            if index is None:
                print(f"Error in result_list: invalid cluster id {result.get('cluster')!r}")
                return cls.OTHER_SUB_CLUSTER
            return sub_cluster_names[index]
        return cls.OTHER_SUB_CLUSTER

    def start_process(self, input_path, titles_path, output_path):
        """Label every row of an Excel sheet with a sub-cluster and save the result.

        Args:
            input_path: Excel file with the columns ``topic`` and ``cluster_llm``.
            titles_path: UTF-8 JSON file holding the list of
                ``{"cluster_name", "sub_cluster_names"}`` dicts.
            output_path: Destination Excel file; receives all input columns
                plus ``sub_cluster``, with every cell sanitized to text.
        """
        df = pd.read_excel(input_path)
        df_copy = df.copy()
        # The titles contain Persian text; do not depend on the locale default.
        with open(titles_path, "r", encoding="utf-8") as f:
            cluster_sub_cluster_list = json.load(f)

        topics = df["topic"]
        cluster_names = df["cluster_llm"]
        batch_size = 100
        sub_clusters = []
        for i in tqdm(range(0, len(topics), batch_size)):
            start_time = time.time()
            result_list = asyncio.run(
                self.run_llm_async(
                    topics[i:i+batch_size],
                    cluster_names[i:i+batch_size],
                    cluster_sub_cluster_list,
                )
            )
            end_time = time.time()
            print(f"Time taken for llm as reranker: {end_time - start_time} seconds")
            # Fixed pause after each batch (presumably to avoid overloading
            # the inference server — confirm before changing).
            time.sleep(5)
            for j, result in enumerate(result_list):
                # Results come back in row order, so address rows by position.
                label = self._resolve_sub_cluster(
                    result, cluster_names.iloc[i + j], cluster_sub_cluster_list
                )
                sub_clusters.append(label)
                print(label)

        df_copy["sub_cluster"] = sub_clusters
        df_copy = self.sanitize_for_excel(df_copy)
        df_copy.to_excel(output_path)
if __name__ == "__main__":
    # Script entry point: label the topics of the input workbook with
    # sub-clusters taken from the JSON catalogue and write a new workbook.
    input_xlsx = "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3.xlsx"
    titles_json = "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3_subcategory.json"
    output_xlsx = "/home/firouzi/trend_grouping_new/tweet_topic_recreation_post_o3_subcategory.xlsx"

    PostSubClusterLLM().start_process(input_xlsx, titles_json, output_xlsx)