# text_clustering/clustering_pipeline.py
import argparse
import os
import random

import httpx
import pandas as pd
import requests
from hazm import Normalizer
from openai import OpenAI
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
from transformers import AutoModel

from post_cluster import PostClusterLLM
from topic_recreation import TopicRecreation
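
# K-means search range: get_best_k() tries each K in [START_K, END_K)
# and keeps the one with the highest silhouette score.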
START_K = 20
END_K = 60


def get_best_k(embeddings):
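    """Sweep K over [START_K, END_K), score each K-means fit with the
    silhouette score, and return (best_k, labels) for the winning K.
    """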
    max_sil_score = 0
    best_k = START_K
    for k in range(START_K, min(END_K, len(embeddings))):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init="auto")
        labels = kmeans.fit_predict(embeddings)
        sil_score = silhouette_score(embeddings, labels)
        if sil_score > max_sil_score:
            max_sil_score = sil_score
            best_k = k
    # Refit with the winning K to produce the final cluster labels.
    kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(embeddings)
    return best_k, labels


def get_embeddings(names):
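    """Normalize the Persian strings, strip filler words, and embed them
    in batches of 50 with jina-embeddings-v3 (task="separation").
    """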
    model = AutoModel.from_pretrained(
        "jinaai/jina-embeddings-v3", trust_remote_code=True
    ).to("cuda")
    normalizer = Normalizer()
    names = [normalizer.normalize(name) for name in names]
    # Filler words (insult, criticism, critique, support, issues, related,
    # threat, performance, behavior, to, from, in) that would otherwise
    # dominate the embeddings are stripped out.
    adjs = ["توهین", "انتقاد", "نقد", "حمایت", "مسائل", "مربوط", "تهدید", "عملکرد", "رفتار", "به", "از", "در"]
    names_new = []
    for name in names:
        for adj in adjs:
            name = name.replace(adj, "")
        names_new.append(name)
    # Encode in batches of 50 to keep GPU memory bounded.
    embeddings = []
    for batch in tqdm(range(0, len(names_new), 50)):
        embeddings += model.encode(names_new[batch:batch + 50], task="separation").tolist()
    return embeddings


def get_cluster_names(clusters):
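    """Name each cluster by sending up to 20 sampled members to a
    locally hosted gemma-3-27b-it endpoint and collecting one Persian
    name per cluster.
    """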
    headers = {"Content-Type": "application/json"}
    prompt = """
    You are a helpful assistant that generates names for clusters of trends in Persian.
    I will give you a list of trends and you will generate a name for this cluster.
    There might be several different topics in the list, so only consider the dominant topic.
    Just give me the final answer in Persian.
    """
    cluster_names = []
    for data in clusters:
        # Sample at most 20 members so the prompt stays short.
        cluster_samples = random.sample(data, min(20, len(data)))
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": str(cluster_samples)},
        ]
        payload = {
            "model": "google/gemma-3-27b-it",
            "messages": messages,
            "max_tokens": 8000,
        }
        response = requests.post(
            "http://192.168.130.206:4001/v1/chat/completions",
            headers=headers,
            json=payload,
        )
        our_response = response.json()["choices"][0]["message"]["content"]
        cluster_names.append(our_response)
    return cluster_names


def modify_cluster_names(cluster_names):
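    """Ask OpenAI o3 to consolidate the first-pass cluster names into
    roughly 20-30 distinct final topics, returned as one Persian string.
    """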
    # Credentials are read from the environment instead of being
    # hardcoded in the source; set PROXY_URL and OPENAI_API_KEY before
    # running.
    http_client = httpx.Client(proxy=os.environ["PROXY_URL"])
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], http_client=http_client)
    prompt = """
    You are a topic modification expert.
    I will give you a list of topics.
    ## TASK
    Extract meaningful and distinct topics from the list. You may change the names of topics. Produce about 20-30 topics that cover all of them.
    ## RULES
    - You may combine or split topics to do this task.
    - You may rename a topic to make it more general or more specific.
    - The final topics must be distinct, and each must carry a specific meaning of its own.
    - Do not combine topics that are unrelated to each other, e.g. economic with political with social.
    - Do combine topics that are related to each other, e.g. Gaza with Palestine.
    ## MUST
    - All categories must be distinct, with a specific meaning separate from the other categories.
    - No two categories may be similar to each other.
    I will trust your intelligence.
    Write the final answer in Persian.
    """
    response = client.chat.completions.create(
        model="o3",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": str(cluster_names)},
        ]
    )
    out = response.choices[0].message.content
    return out


def main(input_file, output_file):
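    """Run the full pipeline: embed the topics, cluster them, name the
    clusters, cluster the names again, consolidate with the LLM, and
    write the final topic list to output_file.
    """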
    # Read the recreated topics from the input spreadsheet.
    df = pd.read_excel(input_file)
    topics = df["topic_recreation"].tolist()
    # Embed the topics.
    embeddings = get_embeddings(topics)
    # Pick the best K and get the K-means labels for it.
    best_k, labels = get_best_k(embeddings)
    # Group the topics into their clusters.
    clusters = [[] for _ in range(best_k)]
    for topic, label in zip(topics, labels):
        clusters[label].append(topic)
    # Name each cluster with the LLM.
    cluster_names = get_cluster_names(clusters)
    # Second pass: embed and cluster the cluster names themselves.
    cluster_names_embeddings = get_embeddings(cluster_names)
    best_k_cluster_names, labels_cluster_names = get_best_k(cluster_names_embeddings)
    clusters_cluster_names = [[] for _ in range(best_k_cluster_names)]
    for cluster_name, label in zip(cluster_names, labels_cluster_names):
        clusters_cluster_names[label].append(cluster_name)
    # Consolidate the grouped names into the final topic list.
    cluster_names_modify = modify_cluster_names(clusters_cluster_names)
    # modify_cluster_names returns a single string, so write it out
    # whole rather than iterating it character by character.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(cluster_names_modify.strip())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    args = parser.parse_args()
    # Stage 1: recreate/clean the raw topics.
    topic_recreation = TopicRecreation()
    topic_file = args.output_file.replace(".xlsx", "_topic_recreation.xlsx")
    topic_recreation.start_process(args.input_file, topic_file)
    # Stage 2: cluster the recreated topics and extract the final titles.
    titles_file = args.output_file.replace(".xlsx", "_titles.txt")
    main(topic_file, titles_file)
    # Stage 3: assign posts to the final clusters.
    post_cluster = PostClusterLLM()
    post_cluster.start_process(titles_file, args.output_file)
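
# Example invocation (hypothetical file names):
#   python clustering_pipeline.py --input_file trends.xlsx --output_file clusters.xlsx
# With these names the run produces clusters_topic_recreation.xlsx,
# clusters_titles.txt, and finally clusters.xlsx.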