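"""Hierarchical sub-clustering of Persian topics.

Pipeline: read labeled topics from an Excel sheet, embed them with
jinaai/jina-embeddings-v3, KMeans-cluster the topics of each top-level
category (choosing K by silhouette score), name the sub-clusters with a
local gemma-3-27b-it endpoint, consolidate those names with OpenAI o3,
and write the result to a JSON file.
"""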
import argparse
import json
import os
import random
import re

import httpx
import pandas as pd
import requests
from hazm import Normalizer
from openai import OpenAI
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
from transformers import AutoModel

# Range of cluster counts to try when searching for the best K.
START_K = 2
END_K = 60

def sanitize_for_excel(text):
    """Remove zero-width and bidi control characters that can confuse Excel rendering."""
    if text is None:
        return ""
    s = str(text)
    # Characters to remove: ZWNJ, ZWJ, LRM, RLM, LRE, RLE, PDF, LRO, RLO, BOM, Tatweel
    remove_chars = [
        "\u200c",  # ZWNJ
        "\u200d",  # ZWJ
        "\u200e",  # LRM
        "\u200f",  # RLM
        "\u202a",  # LRE
        "\u202b",  # RLE
        "\u202c",  # PDF
        "\u202d",  # LRO
        "\u202e",  # RLO
        "\ufeff",  # BOM
        "\u0640",  # Tatweel
    ]
    for ch in remove_chars:
        s = s.replace(ch, "")
    # Collapse runs of whitespace.
    s = re.sub(r"\s+", " ", s).strip()
    return s

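# Illustrative example (characters drawn from the remove list above):
#   sanitize_for_excel("\u202bخبر\u200c  فوری\u202c ") -> "خبر فوری"
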
def get_best_k(embeddings):
    """Pick the K in [START_K, min(END_K, n)) with the best silhouette score, then refit."""
    if len(embeddings) <= START_K:
        # Too few points to search over K; treat everything as one cluster.
        return 1, [0] * len(embeddings)

    max_sil_score = 0
    best_k = START_K
    for k in range(START_K, min(END_K, len(embeddings))):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = kmeans.fit_predict(embeddings)

        sil_score = silhouette_score(embeddings, labels)
        if sil_score > max_sil_score:
            max_sil_score = sil_score
            best_k = k

    # Refit with the winning K so the returned labels correspond to best_k.
    kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(embeddings)

    return best_k, labels

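# Note: silhouette_score requires 2 <= k <= n_samples - 1, which the loop bound
# range(START_K, min(END_K, len(embeddings))) guarantees.
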
_model = None  # Cached embedding model; loading it once avoids a reload on every call.


def get_embeddings(names):
    """Embed Persian topic strings with jina-embeddings-v3 using its "separation" task."""
    global _model
    if _model is None:
        _model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to("cuda")

    normalizer = Normalizer()
    names = [normalizer.normalize(name) for name in names]

    # Generic Persian filler words (e.g. "insult", "criticism", "support", and the
    # prepositions "to"/"from"/"in") that would otherwise dominate the embeddings.
    stopwords = ["توهین", "انتقاد", "نقد", "حمایت", "مسائل", "مربوط", "تهدید", "عملکرد", "رفتار", "به", "از", "در"]

    names_new = []
    for name in names:
        for word in stopwords:
            # Match whole words only, so short prepositions are not cut out of longer words.
            name = re.sub(rf"\b{re.escape(word)}\b", "", name)
        names_new.append(name)

    # Encode in batches of 50 to bound GPU memory use.
    embeddings = []
    for start in tqdm(range(0, len(names_new), 50)):
        embeddings += _model.encode(names_new[start:start + 50], task="separation").tolist()

    return embeddings

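# Note: .to("cuda") assumes a GPU is available. A hypothetical CPU fallback
# (requires `import torch`) would be:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   _model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to(device)
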
def get_cluster_names(clusters):
    """Ask the local gemma endpoint to name each cluster from a sample of its topics."""
    headers = {"Content-Type": "application/json"}

    prompt = """
    You are a helpful assistant that generates Persian names for clusters of topics.
    I will give you a list of topics and you will generate a name for the cluster.
    The list may contain a few unrelated topics, so consider only the dominant one.
    Be specific about the cluster name.
    Just give me the final answer, in Persian.
    """

    cluster_names = []
    for data in clusters:

        # Clusters with fewer than 10 topics are too noisy to name reliably.
        if len(data) < 10:
            continue

        cluster_samples = random.sample(data, min(20, len(data)))

        messages = [{"role": "system", "content": prompt}, {"role": "user", "content": str(cluster_samples)}]

        payload = {
            "model": "google/gemma-3-27b-it",
            "messages": messages,
            "max_tokens": 8000
        }

        response = requests.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload)
        response.raise_for_status()
        cluster_names.append(response.json()["choices"][0]["message"]["content"])

    return cluster_names

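# Note: because clusters with fewer than 10 topics are skipped, the returned
# list can be shorter than `clusters`; main() only uses the names themselves,
# not their positional alignment with the input.
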
def modify_cluster_names(cluster_names, title, best_k):
    """Have OpenAI o3 consolidate raw sub-cluster names into a smaller, distinct set."""
    # Proxy URL and API key are read from environment variables
    # (variable names chosen here; the original hardcoded the credentials).
    http_client = httpx.Client(proxy=os.environ["OPENAI_PROXY_URL"])
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], http_client=http_client)

    # Ask for a round number of sub-categories, roughly half of best_k.
    start = (best_k // 2) - ((best_k // 2) % 10)
    if start == 0:
        start = 1

    prompt = f"""
    You are a sub-category refinement expert.

    I will give you a list of topics.

    All of these topics belong to the {title} category.

    ## TASK
    Extract meaningful and distinct sub-categories from the list. You may rename topics. Return roughly {start}-{start + 10} topics that cover all of them.

    ## RULES
    - You may combine or split topics as needed.
    - You may rename topics to make them more general or more specific.
    - The final topics must be distinct, and each must have a specific meaning of its own.
    - Do not combine unrelated topics (e.g. economic with political or social ones).
    - Do combine related topics (e.g. Gaza with Palestine).

    ## MUST
    - All sub-categories must be distinct in meaning from one another.
    - No two categories may be similar to each other.
    - Be specific about the sub-categories.

    I will trust your intelligence.
    Write the final answer in Persian.
    """

    response = client.chat.completions.create(
        model="o3",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": str(cluster_names)}
        ]
    )
    return response.choices[0].message.content

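# Example of the requested range: best_k = 44 -> best_k // 2 = 22 -> start = 20,
# so the prompt asks for roughly 20-30 sub-categories.
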
def extract_list(text, count):
    """Use the local gemma endpoint to pull titles out of free text as a JSON list."""
    headers = {"Content-Type": "application/json"}

    prompt = """
    Extract the titles from this text and put them in a list.
    Just return the output in list format, do not include any other text: ["title_1", "title_2", ...]
    """

    messages = [{"role": "system", "content": prompt}, {"role": "user", "content": text}]

    payload = {
        "model": "google/gemma-3-27b-it",
        "messages": messages,
        "max_tokens": 8000
    }

    response = requests.post("http://192.168.130.206:4001/v1/chat/completions", headers=headers, json=payload)
    out = response.json()["choices"][0]["message"]["content"]

    # The model sometimes wraps its answer in a ```json fence; strip it before parsing.
    out = out.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    try:
        out = json.loads(out)
    except json.JSONDecodeError:
        # Fall back to returning the raw string so the caller can inspect the failure.
        print(f"error in extract list {count}")
    return out

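# Illustrative round trip: if the model replies '["سیاست خارجی", "اقتصاد"]',
# extract_list returns the Python list ["سیاست خارجی", "اقتصاد"].
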
def main(input_file, output_file):
    # Read the input file.
    df = pd.read_excel(input_file)
    topics = df["topic"].tolist()
    cluster_llms = df["cluster_llm"].tolist()

    # Embed every topic once up front.
    embeddings = get_embeddings(topics)

    # Load the top-level cluster titles (one per line) produced earlier by o3.
    with open("titles_o3.txt", "r", encoding="utf-8") as f:
        titles = f.readlines()
    titles = [sanitize_for_excel(title.strip()) for title in titles]

    # Group embeddings and topics by their top-level title.
    embedding_cluster = [[] for _ in titles]
    topic_cluster = [[] for _ in titles]
    for m in range(len(titles)):
        for embedding, cluster_name, topic in zip(embeddings, cluster_llms, topics):
            if cluster_name == titles[m]:
                embedding_cluster[m].append(embedding)
                topic_cluster[m].append(topic)

    sub_cluster_names = []
    for cluster_count in tqdm(range(len(titles))):
        print(f"start {cluster_count}\n")
        # Find the best K and the KMeans labels for this title's topics.
        best_k, labels = get_best_k(embedding_cluster[cluster_count])
        print(f"initial best_k {best_k}\n")

        # Bucket the topics by KMeans label.
        clusters = [[] for _ in range(best_k)]
        for i in range(len(clusters)):
            for topic, label in zip(topic_cluster[cluster_count], labels):
                if label == i:
                    clusters[i].append(topic)

        # Name each sub-cluster.
        cluster_names = get_cluster_names(clusters)

        if len(cluster_names) > 1:
            # Embed and re-cluster the sub-cluster names to merge near-duplicates.
            cluster_names_embeddings = get_embeddings(cluster_names)
            best_k_cluster_names, labels_cluster_names = get_best_k(cluster_names_embeddings)
            print(f"second best_k {best_k_cluster_names}\n")

            clusters_cluster_names = [[] for _ in range(best_k_cluster_names)]
            for i in range(len(clusters_cluster_names)):
                for cluster_name, label in zip(cluster_names, labels_cluster_names):
                    if label == i:
                        clusters_cluster_names[i].append(cluster_name)

            # Let o3 consolidate the grouped names into the final sub-categories.
            cluster_names_modify = modify_cluster_names(clusters_cluster_names, titles[cluster_count], best_k)
            cluster_names_modify_list = extract_list(cluster_names_modify, cluster_count)
            sub_cluster_names.append({"id": cluster_count, "cluster_name": titles[cluster_count], "sub_cluster_names": cluster_names_modify_list})
        else:
            sub_cluster_names.append({"id": cluster_count, "cluster_name": titles[cluster_count], "sub_cluster_names": []})

    # Save the cluster names.
    if not output_file.endswith(".json"):
        output_file = output_file + ".json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(sub_cluster_names, f, ensure_ascii=False, indent=2)

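# The resulting JSON file has this shape (values illustrative):
# [
#   {"id": 0, "cluster_name": "...", "sub_cluster_names": ["...", "..."]},
#   ...
# ]
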
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    args = parser.parse_args()

    # Run the full sub-clustering pipeline.
    main(args.input_file, args.output_file)
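# Example invocation (assuming this file is saved as sub_cluster.py):
#   python sub_cluster.py --input_file topics.xlsx --output_file sub_clusters
# A ".json" suffix is appended to the output filename if it is missing.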