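"""Topic clustering pipeline for Persian trends.

Recreates topics (TopicRecreation), embeds them with jina-embeddings-v3,
clusters them with KMeans (k chosen by silhouette score), names the clusters
with an LLM, consolidates the names into 20-30 final topics, and then runs
post-clustering (PostClusterLLM) against those topics.

Example invocation (the script name here is a placeholder):
    python cluster_topics.py --input_file trends.xlsx --output_file clusters.xlsx

Expects PROXY_URL and OPENAI_API_KEY in the environment; these variable
names are assumed here, adjust them to your deployment.
"""
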
import argparse
import os
import random

import httpx
import pandas as pd
import requests
from hazm import Normalizer
from openai import OpenAI
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
from transformers import AutoModel

from post_cluster import PostClusterLLM
from topic_recreation import TopicRecreation

# Range of cluster counts to try when searching for the best KMeans k.
START_K = 20
END_K = 60


def get_best_k(embeddings):
    """Search k in [START_K, min(END_K, n_samples)) for the highest silhouette score.

    Assumes there are more than START_K samples.
    """
    max_sil_score = 0
    best_k = START_K
    for k in range(START_K, min(END_K, len(embeddings))):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init="auto")
        labels = kmeans.fit_predict(embeddings)

        sil_score = silhouette_score(embeddings, labels)
        if sil_score > max_sil_score:
            max_sil_score = sil_score
            best_k = k

    # Refit with the winning k so the returned labels match best_k.
    kmeans = KMeans(n_clusters=best_k, random_state=42, n_init="auto")
    labels = kmeans.fit_predict(embeddings)

    return best_k, labels


def get_embeddings(names):
    """Embed Persian strings with jina-embeddings-v3 using its clustering ("separation") task."""
    model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to("cuda")

    # Normalize the Persian text (character forms, spacing, etc.).
    normalizer = Normalizer()
    names = [normalizer.normalize(name) for name in names]

    # Boilerplate words and stopwords stripped so the embeddings focus on the topic itself.
    adjs = ["توهین", "انتقاد", "نقد", "حمایت", "مسائل", "مربوط", "تهدید", "عملکرد", "رفتار", "به", "از", "در"]

    names_new = []
    for name in names:
        for adj in adjs:
            name = name.replace(adj, "")
        names_new.append(name)

    # Encode in batches of 50.
    embeddings = []
    for batch in tqdm(range(0, len(names_new), 50)):
        embeddings += model.encode(names_new[batch:batch + 50], task="separation").tolist()

    return embeddings


def get_cluster_names(clusters):
    """Name each cluster in Persian via a local OpenAI-compatible LLM endpoint."""
    headers = {"Content-Type": "application/json"}

    prompt = """
    You are a helpful assistant that generates names for clusters of trends in Persian.
    I will give you a list of trends and you will generate a name for this cluster.
    There might be several different topics in the list, so consider only the dominant topic.
    Just give me the final answer in Persian.
    """

    cluster_names = []
    for data in clusters:
        # Sample at most 20 trends per cluster to keep the prompt short.
        cluster_samples = random.sample(data, min(20, len(data)))

        messages = [{"role": "system", "content": prompt}, {"role": "user", "content": str(cluster_samples)}]

        payload = {
            "model": "google/gemma-3-27b-it",
            "messages": messages,
            "max_tokens": 8000
        }

        response = requests.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload)
        cluster_name = response.json()["choices"][0]["message"]["content"]
        cluster_names.append(cluster_name)

    return cluster_names


def modify_cluster_names(cluster_names):
    """Ask OpenAI to consolidate the cluster names into 20-30 distinct topics."""
    # Credentials are read from the environment rather than hardcoded:
    # PROXY_URL holds the proxy address, OPENAI_API_KEY the API key.
    http_client = httpx.Client(proxy=os.environ["PROXY_URL"])
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], http_client=http_client)

    prompt = """
    You are a topic modification expert.

    I will give you a list of topics.

    ## TASK
    Extract meaningful and distinct topics from the list. You may change the names of topics. Produce about 20-30 topics that cover all of them.

    ## RULES
    - You may combine or split topics to accomplish this task.
    - You may rename topics to make them more general or more specific.
    - The final topics must be distinct, each with a specific meaning of its own.
    - Don't combine topics that are unrelated to each other, e.g. economic with political or social.
    - Do combine topics that are related to each other, e.g. Gaza with Palestine.

    ## MUST
    - All categories must be distinct and carry a specific meaning relative to the other categories.
    - No two categories may be similar to each other.

    I will trust your intelligence.
    Write the final answer in Persian.
    """

    response = client.chat.completions.create(
        model="o3",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": str(cluster_names)}
        ]
    )
    out = response.choices[0].message.content

    return out


def main(input_file, output_file):
    # Read the recreated topics produced by TopicRecreation.
    df = pd.read_excel(input_file)
    topics = df["topic_recreation"].tolist()

    # Embed the topics.
    embeddings = get_embeddings(topics)

    # Find the best k and the KMeans labels for that k.
    best_k, labels = get_best_k(embeddings)

    # Group topics by cluster label.
    clusters = [[] for _ in range(best_k)]
    for topic, label in zip(topics, labels):
        clusters[label].append(topic)

    # Name each cluster.
    cluster_names = get_cluster_names(clusters)

    # Cluster the cluster names themselves to merge near-duplicates.
    cluster_names_embeddings = get_embeddings(cluster_names)
    best_k_cluster_names, labels_cluster_names = get_best_k(cluster_names_embeddings)

    clusters_cluster_names = [[] for _ in range(best_k_cluster_names)]
    for cluster_name, label in zip(cluster_names, labels_cluster_names):
        clusters_cluster_names[label].append(cluster_name)

    # Consolidate the grouped names into the final topic list.
    cluster_names_modify = modify_cluster_names(clusters_cluster_names)

    # modify_cluster_names returns the model's answer as a single string,
    # so write it to the output file as-is.
    with open(output_file, "w") as f:
        f.write(cluster_names_modify)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    args = parser.parse_args()

    # Step 1: recreate topics from the raw input.
    topic_recreation = TopicRecreation()
    topic_file = args.output_file.replace(".xlsx", "_topic_recreation.xlsx")
    topic_recreation.start_process(args.input_file, topic_file)

    # Step 2: extract consolidated topic titles from the recreated topics.
    titles_file = args.output_file.replace(".xlsx", "_titles.txt")
    main(topic_file, titles_file)

    # Step 3: apply post-clustering against the extracted titles.
    post_cluster = PostClusterLLM()
    post_cluster.start_process(titles_file, args.output_file)