# text_clustering/sub_clustering_pipeline.py
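"""Sub-clustering pipeline (summary inferred from the code below).

Reads topics and their top-level cluster labels from an Excel file, embeds the topics
with jina-embeddings-v3, picks a per-cluster number of KMeans clusters by silhouette
score, names the resulting sub-clusters with an LLM, consolidates those names with a
second LLM pass, and writes the final sub-category lists to JSON.
"""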
import argparse
import json
import os
import random
import re

import httpx
import pandas as pd
import requests
from hazm import Normalizer
from openai import OpenAI
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
from transformers import AutoModel

# Range of cluster counts searched when picking k for KMeans (upper bound exclusive).
START_K = 2
END_K = 60


def sanitize_for_excel(text):
    """Remove zero-width and bidi control characters that can confuse Excel rendering."""
    if text is None:
        return ""
    s = str(text)
    # Characters to remove: ZWNJ, ZWJ, LRM, RLM, bidi embedding/override controls, BOM, Tatweel.
    remove_chars = [
        "\u200c",  # ZWNJ
        "\u200d",  # ZWJ
        "\u200e",  # LRM
        "\u200f",  # RLM
        "\u202a",  # LRE
        "\u202b",  # RLE
        "\u202c",  # PDF
        "\u202d",  # LRO
        "\u202e",  # RLO
        "\ufeff",  # BOM
        "\u0640",  # Tatweel
    ]
    for ch in remove_chars:
        s = s.replace(ch, "")
    # Collapse runs of whitespace into single spaces.
    s = re.sub(r"\s+", " ", s).strip()
    return s
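# Illustrative example (hypothetical input): sanitize_for_excel("\u200eگزارش\u200c  هفتگی\ufeff")
# returns "گزارش هفتگی": the control characters are stripped and the whitespace collapsed.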


def get_best_k(embeddings):
    """Choose the KMeans cluster count in [START_K, END_K) with the best silhouette score."""
    max_sil_score = 0
    best_k = START_K
    for k in range(START_K, min(END_K, len(embeddings))):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = kmeans.fit_predict(embeddings)
        sil_score = silhouette_score(embeddings, labels)
        if sil_score > max_sil_score:
            max_sil_score = sil_score
            best_k = k
    # Refit with the winning k so the returned labels correspond to best_k.
    kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(embeddings)
    return best_k, labels
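# Rough usage sketch (hypothetical data, not part of the pipeline):
#   import numpy as np
#   best_k, labels = get_best_k(np.random.rand(200, 1024))  # one label per row
# Only k values in [START_K, min(END_K, n_samples)) are scored.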


def get_embeddings(names):
    """Embed Persian topic strings with jina-embeddings-v3 (loads the model onto CUDA on every call)."""
    model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to("cuda")
    normalizer = Normalizer()
    names = [normalizer.normalize(name) for name in names]
    # Generic framing words and prepositions (insult, criticism, critique, support, issues,
    # related, threat, performance, behavior, to, from, in) are stripped so that clustering
    # keys on the subject of a topic rather than its wording.
    adjs = ["توهین", "انتقاد", "نقد", "حمایت", "مسائل", "مربوط", "تهدید", "عملکرد", "رفتار", "به", "از", "در"]
    names_new = []
    for name in names:
        for adj in adjs:
            name = name.replace(adj, "")
        names_new.append(name)
    # Embed in batches of 50 using the model's "separation" task.
    embeddings = []
    for batch in tqdm(range(0, len(names_new), 50)):
        embeddings += model.encode(names_new[batch:batch + 50], task="separation").tolist()
    return embeddings
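# Hypothetical call, matching how main() uses it:
#   vectors = get_embeddings(["حمایت از تیم ملی", "نقد عملکرد دولت"])
# returns one embedding (a plain Python list of floats) per input topic.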


def get_cluster_names(clusters):
    """Ask a local chat-completions endpoint for a Persian name for each cluster of topics.

    Clusters with fewer than 10 topics are skipped, so the returned list may be shorter
    than the input.
    """
    headers = {"Content-Type": "application/json"}
    prompt = """
You are a helpful assistant that generates names for clusters of topics in Persian.
I will give you a list of topics and you will generate a name for this cluster.
There might be some different topics in the list, so only consider the dominant topic.
Be specific about the cluster name.
Just give me the final answer in Persian.
"""
    cluster_names = []
    for data in clusters:
        if len(data) < 10:
            continue
        # Sample up to 20 topics so the prompt stays short.
        cluster_samples = random.sample(data, min(20, len(data)))
        messages = [{"role": "system", "content": prompt}, {"role": "user", "content": str(cluster_samples)}]
        payload = {
            "model": "google/gemma-3-27b-it",
            "messages": messages,
            "max_tokens": 8000,
        }
        response = requests.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload)
        our_response = response.json()["choices"][0]["message"]["content"]
        cluster_names.append(our_response)
    return cluster_names


def modify_cluster_names(cluster_names, title, best_k):
    """Consolidate the sub-cluster names of one top-level category with an OpenAI model."""
    # The proxy URL and OpenAI API key are read from the environment
    # (OPENAI_PROXY_URL and OPENAI_API_KEY are assumed variable names).
    http_client = httpx.Client(proxy=os.environ["OPENAI_PROXY_URL"])
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], http_client=http_client)
    # Aim for roughly half of best_k, rounded down to a multiple of 10, as the lower end
    # of the requested sub-category count.
    start = (best_k // 2) - ((best_k // 2) % 10)
    if start == 0:
        start = 1
    prompt = f"""
You are a sub-category modification expert.
I will give you a list of topics.
All these topics belong to the {title} category.
## TASK
Extract meaningful and distinct sub-categories from the list. You may change the names of topics. Return about {start}-{start + 10} topics that cover all of them.
## RULES
- You may combine or split topics to do this task.
- You may change the names of topics to make them more general or more specific.
- The final topics must be distinct and each must carry a specific meaning of its own.
- Don't combine topics that are unrelated to each other, such as economic with political with social.
- Combine topics that are related to each other, such as Gaza with Palestine.
## MUST
- All sub-categories must be distinct and carry a meaning distinct from the other categories.
- Two categories cannot be similar to each other.
- Be specific about the sub-categories.
I will trust your intelligence.
Write the final answer in Persian.
"""
    response = client.chat.completions.create(
        model="o3",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": str(cluster_names)},
        ],
    )
    out = response.choices[0].message.content
    return out


def extract_list(text, count):
    """Extract the list of titles embedded in the model's free-form answer."""
    headers = {"Content-Type": "application/json"}
    prompt = """
Extract the titles from this text and put them in a list.
Just return the output in list format, do not include any other text: ["title_1", "title_2", ...]
"""
    messages = [{"role": "system", "content": prompt}, {"role": "user", "content": text}]
    payload = {
        "model": "google/gemma-3-27b-it",
        "messages": messages,
        "max_tokens": 8000,
    }
    response = requests.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload)
    out = response.json()["choices"][0]["message"]["content"]
    try:
        out = json.loads(out)
    except json.JSONDecodeError:
        # Leave out as the raw string so the caller can inspect what came back.
        print(f"error in extract list {count}")
    return out
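# On success the result looks like ["title_1", "title_2", ...] (per the prompt above);
# if the reply is not valid JSON, the raw reply string is returned unchanged.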


def main(input_file, output_file):
    # Read the input file: one row per topic, with its top-level cluster label.
    df = pd.read_excel(input_file)
    topics = df["topic"].tolist()
    cluster_llms = df["cluster_llm"].tolist()
    # Embed every topic once.
    embeddings = get_embeddings(topics)
    # Read the top-level cluster titles.
    with open("titles_o3.txt", "r") as f:
        titles = f.readlines()
    titles = [sanitize_for_excel(title.strip()) for title in titles]
    # Group embeddings and topics by their top-level cluster title.
    embedding_cluster = [[] for _ in range(len(titles))]
    topic_cluster = [[] for _ in range(len(titles))]
    for m in range(len(titles)):
        for embedding, cluster_name, topic in zip(embeddings, cluster_llms, topics):
            if cluster_name == titles[m]:
                embedding_cluster[m].append(embedding)
                topic_cluster[m].append(topic)
    sub_cluster_names = []
    for cluster_count in tqdm(range(len(titles))):
        print(f"start {cluster_count} \n")
        # Choose k by silhouette score and get KMeans labels for this top-level cluster.
        best_k, labels = get_best_k(embedding_cluster[cluster_count])
        print(f"initial best_k {best_k}\n")
        # Group topics by sub-cluster label.
        clusters = [[] for _ in range(best_k)]
        for i in range(len(clusters)):
            for topic, label in zip(topic_cluster[cluster_count], labels):
                if label == i:
                    clusters[i].append(topic)
        # Name each sub-cluster with the LLM.
        cluster_names = get_cluster_names(clusters)
        if len(cluster_names) > 1:
            # Embed the sub-cluster names and cluster them again to merge near-duplicates.
            cluster_names_embeddings = get_embeddings(cluster_names)
            best_k_cluster_names, labels_cluster_names = get_best_k(cluster_names_embeddings)
            print(f"second best_k {best_k_cluster_names}\n")
            # Group the sub-cluster names by their second-level label.
            clusters_cluster_names = [[] for _ in range(best_k_cluster_names)]
            for i in range(len(clusters_cluster_names)):
                for cluster_name, label in zip(cluster_names, labels_cluster_names):
                    if label == i:
                        clusters_cluster_names[i].append(cluster_name)
            # Consolidate the grouped names into the final sub-categories for this title.
            cluster_names_modify = modify_cluster_names(clusters_cluster_names, titles[cluster_count], best_k)
            cluster_names_modify_list = extract_list(cluster_names_modify, cluster_count)
            sub_cluster_names.append({"id": cluster_count, "cluster_name": titles[cluster_count], "sub_cluster_names": cluster_names_modify_list})
        else:
            sub_cluster_names.append({"id": cluster_count, "cluster_name": titles[cluster_count], "sub_cluster_names": []})
    # Save the sub-cluster names as JSON.
    if not output_file.endswith(".json"):
        output_file = output_file + ".json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(sub_cluster_names, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    args = parser.parse_args()
    # Run the full sub-clustering pipeline.
    main(args.input_file, args.output_file)
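# Example invocation (file names are illustrative). The Excel file needs "topic" and
# "cluster_llm" columns, titles_o3.txt must exist in the working directory, and the
# assumed OPENAI_PROXY_URL / OPENAI_API_KEY environment variables must be set:
#   python sub_clustering_pipeline.py --input_file topics.xlsx --output_file sub_clusters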