add preprocess v1
This commit is contained in:
parent 738d120728
commit 826be2a19e

.gitignore (vendored): 7 changed lines
@ -1 +1,6 @@
data_preprocess/data/*
data
*/__pycache__/*
.env
.venv
*.json
@ -1,152 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "a78759c8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/datasets/load.py:1461: FutureWarning: The repository for Gholamreza/pquad contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Gholamreza/pquad\n",
|
||||
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
|
||||
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
|
||||
" warnings.warn(\n",
|
||||
"Downloading builder script: 4.41kB [00:00, 4.07MB/s]\n",
|
||||
"Downloading readme: 5.15kB [00:00, 7.92MB/s]\n",
|
||||
"Downloading data: 100%|██████████| 26.4M/26.4M [01:05<00:00, 406kB/s] \n",
|
||||
"Downloading data: 100%|██████████| 3.49M/3.49M [00:00<00:00, 5.18MB/s]\n",
|
||||
"Downloading data: 100%|██████████| 3.45M/3.45M [00:00<00:00, 5.38MB/s]\n",
|
||||
"Generating train split: 0%| | 0/63994 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/root/.cache/huggingface/datasets/downloads/e49d5f650d69a5999fe6ceb4438a023cccdcf3e6519abc4dabce736f91595591\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generating train split: 100%|██████████| 63994/63994 [00:02<00:00, 21411.84 examples/s]\n",
|
||||
"Generating validation split: 21%|██▏ | 1703/7976 [00:00<00:00, 16945.09 examples/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/root/.cache/huggingface/datasets/downloads/ea42ddfa9db6f39bc3249a878c853a6f6b466f6217a360bbb8afbac9410d84cc\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generating validation split: 100%|██████████| 7976/7976 [00:00<00:00, 23678.57 examples/s]\n",
|
||||
"Generating test split: 18%|█▊ | 1434/8002 [00:00<00:00, 10262.32 examples/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/root/.cache/huggingface/datasets/downloads/d6ba3b80ff2a6d0333454fac286694b5e777518ea141e0dcd7c0558b71624882\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generating test split: 100%|██████████| 8002/8002 [00:00<00:00, 20511.40 examples/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset = load_dataset(\"Gholamreza/pquad\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "c91f659a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"48273\n",
|
||||
"63994\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_dataset = []\n",
|
||||
"for data in dataset[\"train\"]:\n",
|
||||
" if len(data[\"answers\"][\"text\"]) > 0:\n",
|
||||
" all_dataset.append({'question': data['question'], 'passgae_positive': [data['context']], 'passgae_negative': []})\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(len(all_dataset))\n",
|
||||
"print(len(dataset[\"train\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "d66809ce",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'question': 'در 816 مرعشی از حکمرانی چه کسی در تنکابن یاد کرده\\u200cاست؟', 'passgae_positive': ['در ۸۰۶ خواندمیر به ولایت تنکابن اشاره کرده و در ۸۱۶ مرعشی از حکمرانیِ سیدداوود کارکیای تنکابنی، فرزند سیدهادی کیا، در تنکابن یاد کرده\\u200cاست. مَلک کیومرث ــ که در ۸۳۰ به مخالفت با سادات گیلان برخاسته بود ــ در ۸۳۱ عمارت خاصة سید داوود کارکیای تنکابنی را که در اواخر تابستان هنوز در ییلاق به سر می\\u200cبرد، آتش زد و برخی اهالی را به قتل رساند. در ۸۶۵ مازندرانی از «موضع تنکابن» در «مملکت گیلان» نام برده\\u200cاست. مرعشی در ۸۸۹ به حرکت خود از کِلیشُم (از قرای ییلاقی تنکابن) به تنکابن برای تصرف «دشت تنکابن» اشاره کرده\\u200cاست.'], 'passgae_negative': []}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(all_dataset[10000])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9a566e69",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@ -1,235 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "a78759c8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1000212\n",
|
||||
"250666\n",
|
||||
"270642\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"dataset_synthetic_scores = []\n",
|
||||
"with open('/home/firouzi/embedding_model/data_preprocess/data/synthetic-persian-qa-retrieval/train.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" data = json.loads(line)\n",
|
||||
" dataset_synthetic_scores.append(data)\n",
|
||||
"\n",
|
||||
"dataset_synthetic_queries = {}\n",
|
||||
"with open('/home/firouzi/embedding_model/data_preprocess/data/synthetic-persian-qa-retrieval/queries.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" json_data = json.loads(line)\n",
|
||||
" dataset_synthetic_queries[json_data['_id']] = json_data\n",
|
||||
"\n",
|
||||
"dataset_synthetic_corpus = {}\n",
|
||||
"with open('/home/firouzi/embedding_model/data_preprocess/data/synthetic-persian-qa-retrieval/corpus.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" json_data = json.loads(line)\n",
|
||||
" dataset_synthetic_corpus[json_data['_id']] = json_data\n",
|
||||
"\n",
|
||||
"print(len(dataset_synthetic_scores))\n",
|
||||
"print(len(dataset_synthetic_queries))\n",
|
||||
"print(len(dataset_synthetic_corpus))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "bbb2657f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"106520\n",
|
||||
"223423\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_dataset = {}\n",
|
||||
"count = 0\n",
|
||||
"for data_topic in dataset_synthetic_scores:\n",
|
||||
" try:\n",
|
||||
" query_id = data_topic['query-id']\n",
|
||||
" corpus_id = int(data_topic['corpus-id'])\n",
|
||||
" score = data_topic['score']\n",
|
||||
"\n",
|
||||
" passgae_positive = []\n",
|
||||
" passgae_negative = []\n",
|
||||
" if score == \"1\":\n",
|
||||
" passgae_positive.append({'title': dataset_synthetic_corpus[corpus_id]['title'].replace('\\u200c', ' '), 'text': dataset_synthetic_corpus[corpus_id]['text'].replace('\\u200c', ' ')})\n",
|
||||
" if all_dataset.get(query_id, None):\n",
|
||||
" all_dataset[query_id]['passgae_positive'].append({'title': dataset_synthetic_corpus[corpus_id]['title'].replace('\\u200c', ' '), 'text': dataset_synthetic_corpus[corpus_id]['text'].replace('\\u200c', ' ')})\n",
|
||||
" else:\n",
|
||||
" all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'], 'passgae_positive': passgae_positive, 'passgae_negative': passgae_negative}\n",
|
||||
" else:\n",
|
||||
" passgae_negative.append({'title': dataset_synthetic_corpus[corpus_id]['title'].replace('\\u200c', ' '), 'text': dataset_synthetic_corpus[corpus_id]['text'].replace('\\u200c', ' ')})\n",
|
||||
" if all_dataset.get(query_id, None):\n",
|
||||
" all_dataset[query_id]['passgae_negative'].append({'title': dataset_synthetic_corpus[corpus_id]['title'].replace('\\u200c', ' '), 'text': dataset_synthetic_corpus[corpus_id]['text'].replace('\\u200c', ' ')})\n",
|
||||
" else:\n",
|
||||
" all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'], 'passgae_positive': passgae_positive, 'passgae_negative': passgae_negative}\n",
|
||||
" except:\n",
|
||||
" count += 1\n",
|
||||
"print(count)\n",
|
||||
"print(len(all_dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "42166e97",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'query-id': 'train_2', 'corpus-id': '32409', 'score': '0'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data_topic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c91f659a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'query-id': 'train_0', 'corpus-id': '43272', 'score': '1'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataset_synthetic_scores[0]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d66809ce",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'_id': 'test_0',\n",
|
||||
" 'text': 'چگونه نان کدو حلوایی را در فر بپزیم و چه نکاتی برای پخت بهتر وجود دارد؟'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataset_synthetic_queries[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "1cdb5b31",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'_id': 10,\n",
|
||||
" 'title': '',\n",
|
||||
" 'text': 'عبدالرحمن رحمانی یک سیاستمدار افغانستانی است که در دوره شانزدهم مجلس نمایندگان به عنوان نماینده مردم بلخ فعالیت می\\u200cکند. او در این مجلس عضو کمیسیون اقتصاد ملی، سازمان\\u200cهای غیر حکومتی، انکشاف دهات، زراعت و مالداری می\\u200cباشد.'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataset_synthetic_corpus[10]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "e6b8c9af",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'question': 'کتاب «تأثیر فلات زدگی شغلی در سازمان» چه کسانی را هدف قرار داده است؟',\n",
|
||||
" 'passgae_positive': [{'title': '',\n",
|
||||
" 'text': '\"مدیران کسب و کارها\" و \"رهبران تیم ها یا سازمان\\u200cها\" مخاطبان اصلی این کتاب هستند. با مطالعه این اثر می توانند بهتر با موضوع فلات زدگی آشنا شوند، آن را در میان کارکنان خود تشخیص دهند و راه حل هایی برای بهبود عملکرد آنها ارائه کنند.'}],\n",
|
||||
" 'passgae_negative': [{'title': '',\n",
|
||||
" 'text': 'این کتاب به موضوع مدیریت و رهبری اختصاص دارد که توسط پیتر اف. دراکر و جوزف اِی. ماچیاری\\u200cالو نوشته شده است. تمرکز اصلی این کتاب بر مفهوم «انجام دادن کار درست» یا اثربخشی، در مقابل صرفاً انجام صحیح امور (کارایی) است. نویسندگان تأکید می کنند که سازمان ها باید فراتر از صرفاً کارآمد بودن حرکت کرده و اطمینان حاصل کنند که ارزش واقعی برای مشتریان نهایی خلق می شود.'},\n",
|
||||
" {'title': '',\n",
|
||||
" 'text': 'اگر در سازمان یا کسب و کاری فعالیت می کنید که از چنین اصطلاحاتی استفاده می کند، این کتاب به شما کمک خواهد کرد تا ماهیت آنها را شناخته و از تاثیر منفی شان بر دوری جویید. همچنین با خواندن این کتاب متوجه خواهید شد که چگونه می توان گفتمان سازمانی را تغییر داده و سازمان خود را نجات داد.'},\n",
|
||||
" {'title': '',\n",
|
||||
" 'text': 'این کتاب به علاقه\\u200cمندان مدیریت و افرادی که در حوزه\\u200cهای مرتبط با منابع انسانی و بهبود محیط کار فعالیت می\\u200cکنند، پیشنهاد می\\u200cشود. مطالعه این کتاب می\\u200cتواند به مدیران و کارکنان کمک کند تا با شناخت بهتر استرس شغلی و راهکارهای مدیریت آن، به بهبود کیفیت کار و افزایش رضایت شغلی در محیط\\u200cهای کاری بپردازند.'}]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"list(all_dataset.values())[14500]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "74ef02a1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
data_preprocess/generate_random_negative_sample.py (new file, 95 lines)
@ -0,0 +1,95 @@
import random
from tqdm import tqdm
import numpy as np

from data_preprocess.text_embedder import TextEmbedder


THRESHOLD_MULTIPLY = 0.9
RANDOM_NEGATIVE_COUNT = 25
batch_size = 100

text_embedder = TextEmbedder()


def generate_random_negative_sample(all_dataset):
    """
    Generate random negative samples for every record in the dataset.
    Args:
        all_dataset: list of dicts with keys 'question', 'passage_positive',
                     'passage_negative' and 'passage_negative_random'
    Returns:
        the same list of dicts, with 'passage_negative_random' filled in
    """
    len_dataset = len(all_dataset)
    all_dataset_embeddings = [{'question_embedding': "", 'passage_positive_embedding': [], 'passage_negative_embeddings': []} for _ in range(len_dataset)]

    print("calculate question embeddings")
    # calculate question embeddings
    for i in tqdm(range(0, len_dataset, batch_size)):

        question_list = []
        for id in range(i, min(i + batch_size, len_dataset)):
            question_list.append(all_dataset[id]['question'])

        question_embeddings = text_embedder.embed_texts(question_list)

        count = 0
        for id in range(i, min(i + batch_size, len_dataset)):
            all_dataset_embeddings[id]['question_embedding'] = question_embeddings[count]
            count += 1

    print("calculate passage positive embeddings")
    # calculate passage positive embeddings
    for i in tqdm(range(0, len_dataset, batch_size)):

        passage_positive_list = []
        for id in range(i, min(i + batch_size, len_dataset)):
            for passage in all_dataset[id]['passage_positive']:
                passage_positive_list.append(passage)

        passage_positive_embeddings = text_embedder.embed_texts(passage_positive_list)

        count = 0
        for id in range(i, min(i + batch_size, len_dataset)):
            for passage_id in range(len(all_dataset[id]['passage_positive'])):
                all_dataset_embeddings[id]['passage_positive_embedding'].append(passage_positive_embeddings[count])
                count += 1

    print("calculate passage negative embeddings")
    # calculate passage negative embeddings
    for i in tqdm(range(0, len_dataset, batch_size)):

        passage_negative_list = []
        for id in range(i, min(i + batch_size, len_dataset)):
            for passage in all_dataset[id]['passage_negative']:
                passage_negative_list.append(passage)

        passage_negative_embeddings = text_embedder.embed_texts(passage_negative_list)

        count = 0
        for id in range(i, min(i + batch_size, len_dataset)):
            for passage_id in range(len(all_dataset[id]['passage_negative'])):
                all_dataset_embeddings[id]['passage_negative_embeddings'].append(passage_negative_embeddings[count])
                count += 1

    print("getting random negative passages")
    for id in tqdm(range(len_dataset)):
        question_embeddings = all_dataset_embeddings[id]['question_embedding']
        passage_positive_embeddings = all_dataset_embeddings[id]['passage_positive_embedding'][0]

        score_question_passage_positive = np.dot(question_embeddings, passage_positive_embeddings)

        while len(all_dataset[id]['passage_negative_random']) < RANDOM_NEGATIVE_COUNT:
            random_id = random.randint(0, len_dataset - 1)
            if random_id != id:
                all_passages_embedding = all_dataset_embeddings[random_id]['passage_negative_embeddings'] + all_dataset_embeddings[random_id]['passage_positive_embedding']
                all_passages = all_dataset[random_id]['passage_negative'] + all_dataset[random_id]['passage_positive']
                random_passage_id = random.randint(0, len(all_passages) - 1)
                random_passage_embeddings = all_passages_embedding[random_passage_id]
                score_question_random_passage = np.dot(question_embeddings, random_passage_embeddings)

                if score_question_random_passage < THRESHOLD_MULTIPLY * score_question_passage_positive:
                    all_dataset[id]['passage_negative_random'].append(all_passages[random_passage_id])

    return all_dataset
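For reference, a minimal sketch of the record shape the function above expects, inferred from the field accesses in its code. The sample values are made up; actually running it assumes the embedding endpoint configured in text_embedder.py is reachable and that the dataset is realistically large, since the sampling loop keeps drawing from other records until enough sufficiently dissimilar passages are found.

    from data_preprocess.generate_random_negative_sample import generate_random_negative_sample

    # Hypothetical records: every record needs all four keys and at least one
    # positive passage, because the anchor score comes from passage_positive_embedding[0].
    toy_dataset = [
        {"question": "نمونه سوال اول؟",
         "passage_positive": ["متنی که به سوال اول پاسخ می‌دهد."],
         "passage_negative": ["متن نامرتبط."],
         "passage_negative_random": []},
        # ... many more records; random negatives are drawn from the *other* records
    ]

    toy_dataset = generate_random_negative_sample(toy_dataset)
    # each record's 'passage_negative_random' now holds RANDOM_NEGATIVE_COUNT passages
    # whose question similarity is below THRESHOLD_MULTIPLY times the positive score.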
data_preprocess/preprocess_v1.py (new file, 217 lines)
@ -0,0 +1,217 @@
import argparse
from datasets import load_dataset
import json
from tqdm import tqdm

from data_preprocess.remove_false_negative_model import LLMModel
from data_preprocess.generate_random_negative_sample import generate_random_negative_sample


llm_model = LLMModel()


def load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path):
    """
    Load the synthetic dataset from local jsonl files.
    output:
        [{
            "question": "",
            "passage_positive": [],
            "passage_negative": [],
            "passage_negative_random": []
        }]
    """
    dataset_synthetic_scores = []
    with open(synthetic_train_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            dataset_synthetic_scores.append(data)

    dataset_synthetic_queries = {}
    with open(synthetic_queries_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            dataset_synthetic_queries[json_data['_id']] = json_data

    dataset_synthetic_corpus = {}
    with open(synthetic_corpus_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            dataset_synthetic_corpus[json_data['_id']] = json_data

    # build records with question, passage_positive, passage_negative, passage_negative_random
    all_dataset = {}
    for data_topic in dataset_synthetic_scores:

        query_id = data_topic['query-id']
        corpus_id = int(data_topic['corpus-id'])
        score = data_topic['score']

        if (query_id in dataset_synthetic_queries) and (corpus_id in dataset_synthetic_corpus):
            if score == "1":
                if query_id in all_dataset:
                    all_dataset[query_id]['passage_positive'].append(dataset_synthetic_corpus[corpus_id]['text'])
                else:
                    all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'],
                                             'passage_positive': [dataset_synthetic_corpus[corpus_id]['text']],
                                             'passage_negative': [],
                                             'passage_negative_random': []}
            else:
                if query_id in all_dataset:
                    all_dataset[query_id]['passage_negative'].append(dataset_synthetic_corpus[corpus_id]['text'])
                else:
                    all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'],
                                             'passage_positive': [],
                                             'passage_negative': [dataset_synthetic_corpus[corpus_id]['text']],
                                             'passage_negative_random': []}

    all_dataset = list(all_dataset.values())

    return all_dataset


def load_pquad_dataset():
    """
    Load the PQuAD dataset from Hugging Face.
    output:
        [{
            "question": "",
            "passage_positive": [],
            "passage_negative": [],
            "passage_negative_random": []
        }]
    """
    dataset = load_dataset("Gholamreza/pquad", trust_remote_code=True)

    all_dataset = []
    for data in dataset["train"]:
        if len(data["answers"]["text"]) > 0:
            all_dataset.append({'question': data['question'], 'passage_positive': [data['context']], 'passage_negative': [], 'passage_negative_random': []})

    return all_dataset


def remove_false_negative(dataset):
    """
    Remove false negative samples from a dataset.
    Args:
        dataset: list of dicts
    Returns:
        dataset: list of dicts, with false negatives dropped from 'passage_negative'
    """
    negative_count_all = 0
    negative_count_removed = 0
    len_dataset = len(dataset)
    batch_size = 100
    for i in tqdm(range(0, len_dataset, batch_size)):

        question_list = []
        passage_negative_list = []
        for id in range(i, min(i + batch_size, len_dataset)):
            for passage in dataset[id]['passage_negative']:
                question_list.append(dataset[id]['question'])
                passage_negative_list.append(passage)

        results = llm_model.remove_false_negative_llm(question_list, passage_negative_list)

        negative_count_removed += len([_ for _ in results if _ == "1"])
        negative_count_all += len(results)

        count = 0
        for id in range(i, min(i + batch_size, len_dataset)):
            new_negative_list = []
            for passage_id in range(len(dataset[id]['passage_negative'])):
                if results[count] == "0":
                    new_negative_list.append(dataset[id]['passage_negative'][passage_id])
                count += 1
            dataset[id]['passage_negative'] = new_negative_list

    print(f"removed {negative_count_removed} false negative samples from {negative_count_all} samples")
    print("--------------------------------")

    return dataset


def save_dataset(dataset, output_path):
    """
    Save the dataset to a json file.
    Args:
        dataset: list of dicts
        output_path: path to save the dataset
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)


def main(output_path):

    # load synthetic dataset
    print("--------------------------------")
    print("loading synthetic dataset")
    synthetic_train_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/train.jsonl"
    synthetic_corpus_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/corpus.jsonl"
    synthetic_queries_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/queries.jsonl"

    synthetic_dataset = load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path)
    print(f"synthetic dataset loaded : {len(synthetic_dataset)} samples")
    print("--------------------------------")

    # load pquad dataset
    print("loading pquad dataset")
    pquad_dataset = load_pquad_dataset()
    print(f"pquad dataset loaded : {len(pquad_dataset)} samples")
    print("--------------------------------")

    # remove false negative samples from the synthetic dataset
    print("start to remove false negative samples from synthetic dataset")
    synthetic_dataset = remove_false_negative(synthetic_dataset)
    print("successfully removed false negative samples from synthetic dataset")
    print("--------------------------------")

    # remove false negative samples from the pquad dataset
    print("start to remove false negative samples from pquad dataset")
    pquad_dataset = remove_false_negative(pquad_dataset)
    print("successfully removed false negative samples from pquad dataset")
    print("--------------------------------")

    # merge the synthetic and pquad datasets
    print("start to merge synthetic and pquad dataset")
    all_dataset = synthetic_dataset + pquad_dataset
    print("successfully merged synthetic and pquad dataset")
    print("--------------------------------")

    # generate random negative samples
    print("start to generate random negative samples")
    all_dataset = generate_random_negative_sample(all_dataset)
    print("successfully generated random negative samples")
    print("--------------------------------")

    # save dataset
    print("start to save dataset")
    save_dataset(all_dataset, output_path)
    print("successfully saved dataset")
    print("--------------------------------")


if __name__ == "__main__":
    """
    Preprocess the dataset for training.

    pipeline:
        load synthetic dataset from local jsonl files
        load pquad dataset from huggingface
        remove false negative samples from synthetic dataset
        remove false negative samples from pquad dataset
        merge synthetic and pquad dataset
        generate random negative samples
        save dataset to json file

    python preprocess_v1.py --output_path /home/firouzi/embedding_model/data/train.json
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    output_path = args.output_path

    main(output_path)
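One note for anyone running this script: it imports from the data_preprocess package, so the command in the docstring presumably has to be launched from the repository root (or as python -m data_preprocess.preprocess_v1). A hypothetical programmatic equivalent, with a made-up output path:

    # Run the same pipeline without argparse. Assumes the repository root is the
    # working directory so the data_preprocess imports resolve, and that the LLM
    # and embedding services referenced by the helper modules are reachable.
    from data_preprocess.preprocess_v1 import main

    main("./data/train_v1.json")  # hypothetical output location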
data_preprocess/remove_false_negative_model.py (new file, 109 lines)
@ -0,0 +1,109 @@
from typing import List
import json
import asyncio
import aiohttp
import time
import re
import os
from dotenv import load_dotenv

load_dotenv()

model = os.getenv('LLM_AS_RERANKER_MODEL')
model_url = os.getenv('LLM_AS_RERANKER_URL')
model_pass = os.getenv('LLM_AS_RERANKER_PASS')


class LLMModel:
    def __init__(self):

        self.instruction = """
        You are a helpful assistant that help me to find that the text is relevant to the question or not.
        You are given a question and a text.
        You must evaluate the text based on the question and return "1" if the text is relevant to the question and "0" if the text is not relevant to the question.

        be carefull, I have chosen the text randomly from my dataset so the text must answer the question independently.
        You must return the result in the following format:
        {{"result": "1" or "0"}}
        """

    async def run_llm(self, session, question, text):
        """
        Run the llm model on a single (question, text) pair.
        Args:
            session: The session to use for the request.
            question: The question to evaluate the text against.
            text: The text to evaluate.
        Returns:
            "1" if the text is relevant to the question, otherwise "0".
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {model_pass}"}

        input_message = f"""{{"question": "{question}", "text": "{text}"}}"""
        messages = [{"role": "system", "content": self.instruction}, {"role": "user", "content": input_message}]

        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": 100
        }
        try:
            async with session.post(model_url + "/chat/completions", headers=headers, json=payload) as resp:
                resp.raise_for_status()
                response = await resp.json()

                out = response['choices'][0]['message']['content']

                match = re.search(r'"result":\s*"?([\d\.]+)"?', out)

                if match:
                    result = match.group(1)

                    if result not in ["0", "1"]:
                        print(f"Unexpected result from llm model: {out}")
                        return "0"

                    return result

                # no parsable result in the model output: treat as "not relevant"
                return "0"

        except Exception as e:
            try:
                print(f"Error in llm model {out}: {e}")
            except:
                print(f"Error in llm model: {e}")
            return "0"

    async def run_llm_async(self, question_list, text_list):
        """
        Send all chunk requests concurrently.
        Args:
            question_list: The list of questions.
            text_list: The list of texts.
        Returns:
            The list of results.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [self.run_llm(session, question, text) for question, text in zip(question_list, text_list)]
            results = await asyncio.gather(*tasks)
            return results

    def remove_false_negative_llm(self, query_list: List[str], text_list: List[str]) -> List[str]:
        """
        Label candidate negative passages against their queries using the LLM model.
        Args:
            query_list: The list of queries.
            text_list: The list of texts.
        Returns:
            A list of "1"/"0" labels, one per (query, text) pair; "1" marks a text
            that is actually relevant to its query (a false negative).
        """
        if not text_list:
            return []

        start_time = time.time()
        results = asyncio.run(self.run_llm_async(query_list, text_list))
        end_time = time.time()
        print(f"Time taken for llm model: {end_time - start_time} seconds")

        return results
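A small sketch of how this class is consumed by remove_false_negative in preprocess_v1.py: the question is repeated once per candidate negative passage, the returned labels line up positionally, and "1" marks a passage the LLM judges relevant (a false negative the caller drops). The texts below are placeholders, and the LLM_AS_RERANKER_* variables are assumed to be set in .env.

    from data_preprocess.remove_false_negative_model import LLMModel

    llm_model = LLMModel()

    questions = ["جنگ جهانی اول چه زمانی پایان یافت؟"] * 2     # one entry per candidate passage
    candidates = ["متنی درباره پایان جنگ جهانی اول.",          # placeholder passages
                  "متنی درباره قرن نوزدهم."]

    labels = llm_model.remove_false_negative_llm(questions, candidates)
    # e.g. ["1", "0"]: the first candidate actually answers the question, so the
    # caller keeps only passages labelled "0" as true negatives.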
data_preprocess/text_embedder.py (new file, 40 lines)
@ -0,0 +1,40 @@
from hazm import Normalizer
import requests
import numpy as np
from dotenv import load_dotenv
import os

load_dotenv()


class TextEmbedder:
    def __init__(self, model_name="BAAI/bge-m3"):
        self.model_name = model_name
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {os.getenv('EMBEDDING_PASS')}"}
        self.normalizer = Normalizer()

    def preprocess_embedder(self, text: str):
        text = text.replace("\n", ".")
        text = self.normalizer.normalize(text)

        return text

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """
        Embed texts using the model.
        """
        if texts == []:
            return []

        texts = [self.preprocess_embedder(text) for text in texts]

        payload = {
            "model": self.model_name,
            "input": texts
        }
        responses = requests.post("http://78.38.161.78:3094/v1/embeddings", headers=self.headers, json=payload)
        embeddings = [np.array(response["embedding"]) for response in responses.json()["data"]]

        return embeddings
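A minimal usage sketch, assuming EMBEDDING_PASS is set in .env and the embedding endpoint above is reachable; the input strings are arbitrary examples.

    import numpy as np

    from data_preprocess.text_embedder import TextEmbedder

    embedder = TextEmbedder()  # defaults to BAAI/bge-m3
    vectors = embedder.embed_texts(["یک متن نمونه", "متن نمونه دوم"])
    similarity = np.dot(vectors[0], vectors[1])  # same dot-product score used in generate_random_negative_sample.py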
data_preprocess_notebook/data_loader_gholam_pquad.ipynb (new file, 360 lines)
@ -0,0 +1,360 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a78759c8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/datasets/load.py:1461: FutureWarning: The repository for Gholamreza/pquad contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Gholamreza/pquad\n",
|
||||
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
|
||||
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset = load_dataset(\"Gholamreza/pquad\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c91f659a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"48273\n",
|
||||
"63994\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_dataset = []\n",
|
||||
"for data in dataset[\"train\"]:\n",
|
||||
" if len(data[\"answers\"][\"text\"]) > 0:\n",
|
||||
" all_dataset.append({'question': data['question'], 'passage_positive': [data['context']], 'passage_negative': [], 'passage_negative_random': []})\n",
|
||||
" # else:\n",
|
||||
" # all_dataset.append({'question': data['question'], 'passage_positive': [], 'passage_negative': [data['context']]})\n",
|
||||
"\n",
|
||||
"print(len(all_dataset))\n",
|
||||
"print(len(dataset[\"train\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "d66809ce",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'question': 'جنگ جهانی اول در چه تاریخی پایان یافت؟', 'passgae_positive': [], 'passgae_negative': ['در سال ۱۸۷۱ امپراتوری آلمان با اتحاد پروس و کنفدراسیون جرمن شمالی توسط اتو ون بیسمارک به وجود آمد. این کشور قدرتمند تا سال ۱۹۱۸ ادامه یافت و با عنوان رایش دوم مشهور شد. بیسمارک توانست استان\\u200cهای جدید زیادی را طی جنگ\\u200cهای مبتکرانهٔ کوتاه و دیپلماتیک به دست آورد. او با اتریش هم پیمان شد تا دانمارک را شکست دهد و ناحیهٔ شلزویگ-هولشتاین را تصرف کند. او جنگ اتریش و پروس (آسترو-پروسیان) را آغاز کرد و پیروز شد اما اینکار فقط برای این بود که ایتالیا طرف آلمان را بگیرد. سپس پروس وارد جنگ فرانسه و پروس (فرانکو-پروسین) (۷۱-۱۸۷۰) شد و توانست شکست کاملی به فرانسه وارد سازد. ویلهلم اول به عنوان آخرین توهین به فرانسوی\\u200cها در کاخ ورسای در قلب فرانسه به عنوان امپراتور آلمان سوگند خورد. امپراتوری آلمان تا پایان جنگ جهانی اول یعنی زمانی که فرانسه توانست در پیمان ورسای تلافی بکند در اوج خود بود.']}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(all_dataset[1240])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "9a566e69",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'id': 123802.0,\n",
|
||||
" 'title': 'قرن نوزدهم',\n",
|
||||
" 'context': 'جنگ داخلی ایالات متحده از سال ۱۸۶۱ تا ۱۸۶۵ طول کشید. در طول جنگ آبراهام لینکلن رئیس\\u200cجمهور بود. امروزه او را به عنوان یکی از بزرگ\\u200cترین رهبران جهان غرب تلقی می\\u200cکنند. در همین زمان، آمدن نیروی بخار در کنار انقلاب صنعتی رو به رشد، گسترش شدیدی در کارخانجات صنعتی ایجاد کرد. در سال ۱۸۷۸ توماس ادیسون لامپ برق جدید خود را به نمایش گذاشت و در طول یک دهه سیستم توزیع برق بزرگی را در تمام کشور راه\\u200cاندازی کرد. تأثیر اقتصادی به تدریج به خارج از ایالات متحده و به سمت اقیانوس آرام و آمریکای لاتین گسترش پیدا کرد.',\n",
|
||||
" 'question': 'علت گسترش شدید کارخانجات صنعتی چه بود؟',\n",
|
||||
" 'answers': {'text': ['آمدن نیروی بخار در کنار انقلاب صنعتی رو به رشد'],\n",
|
||||
" 'answer_start': [180]}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataset[\"train\"][1251]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "08a2e2d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from hazm import Normalizer\n",
|
||||
"import requests\n",
|
||||
"import numpy as np\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"load_dotenv()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class TextEmbedder:\n",
|
||||
" def __init__(self, model_name=\"BAAI/bge-m3\"):\n",
|
||||
" self.model_name = model_name\n",
|
||||
" self.headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {os.getenv('EMBEDDING_PASS')}\"}\n",
|
||||
" self.normalizer = Normalizer()\n",
|
||||
" \n",
|
||||
" def preprocess_embedder(self, text:str):\n",
|
||||
" text = text.replace(\"\\n\", \".\")\n",
|
||||
" text = self.normalizer.normalize(text)\n",
|
||||
" \n",
|
||||
" return text\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def embed_texts(self, texts:list[str])->list[list[float]]:\n",
|
||||
" \"\"\"\n",
|
||||
" Embed texts using the model.\n",
|
||||
" \"\"\"\n",
|
||||
" if texts == []:\n",
|
||||
" return []\n",
|
||||
" \n",
|
||||
" texts = [self.preprocess_embedder(text) for text in texts]\n",
|
||||
" \n",
|
||||
" payload = {\n",
|
||||
" \"model\": self.model_name,\n",
|
||||
" \"input\": texts\n",
|
||||
" }\n",
|
||||
" responses = requests.post(\"http://78.38.161.78:3094/v1/embeddings\", headers=self.headers, json=payload)\n",
|
||||
" embeddings = [np.array(response[\"embedding\"]) for response in responses.json()[\"data\"]]\n",
|
||||
" \n",
|
||||
" return embeddings\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "978e4ac3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"calculate question embeddings\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 5/5 [00:00<00:00, 6.35it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"calculate passage positive embeddings\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 5/5 [00:01<00:00, 4.47it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"calculate passage negative embeddings\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 5/5 [00:00<00:00, 21140.65it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"getting random negative passages\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 220/220 [00:00<00:00, 3699.93it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"THRESHOLD_MULTIPLY = 0.9\n",
|
||||
"RANDOM_NEGATIVE_COUNT = 50\n",
|
||||
"batch_size = 50\n",
|
||||
"\n",
|
||||
"all_dataset = all_dataset[1500:1500+220]\n",
|
||||
"\n",
|
||||
"text_embedder = TextEmbedder()\n",
|
||||
"len_dataset = len(all_dataset)\n",
|
||||
"all_dataset_embeddings = [{'question_embedding': \"\", 'passage_positive_embedding': [], 'passage_negative_embeddings': []} for _ in range(len_dataset)]\n",
|
||||
"\n",
|
||||
"print(\"calculate question embeddings\")\n",
|
||||
"# calculate question embeddings\n",
|
||||
"for i in tqdm(range(0, len_dataset, batch_size)):\n",
|
||||
"\n",
|
||||
" question_list = []\n",
|
||||
" for id in range(i, min(i + batch_size, len_dataset)):\n",
|
||||
" question_list.append(all_dataset[id]['question'])\n",
|
||||
"\n",
|
||||
" question_embeddings = text_embedder.embed_texts(question_list)\n",
|
||||
"\n",
|
||||
" count = 0 \n",
|
||||
" for id in range(i, min(i + batch_size, len_dataset)):\n",
|
||||
" all_dataset_embeddings[id]['question_embedding'] = question_embeddings[count]\n",
|
||||
" count += 1\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"calculate passage positive embeddings\")\n",
|
||||
"# calculate passage positive embeddings\n",
|
||||
"for i in tqdm(range(0, len_dataset, batch_size)):\n",
|
||||
"\n",
|
||||
" passage_positive_list = []\n",
|
||||
" for id in range(i, min(i + batch_size, len_dataset)):\n",
|
||||
" for passage in all_dataset[id]['passage_positive']:\n",
|
||||
" passage_positive_list.append(passage)\n",
|
||||
"\n",
|
||||
" passage_positive_embeddings = text_embedder.embed_texts(passage_positive_list)\n",
|
||||
"\n",
|
||||
" count = 0\n",
|
||||
" for id in range(i, min(i + batch_size, len_dataset)):\n",
|
||||
" for passage_id in range(len(all_dataset[id]['passage_positive'])):\n",
|
||||
" all_dataset_embeddings[id]['passage_positive_embedding'].append(passage_positive_embeddings[count])\n",
|
||||
" count += 1\n",
|
||||
"\n",
|
||||
"print(\"calculate passage negative embeddings\")\n",
|
||||
"# calculate passage negative embeddings\n",
|
||||
"for i in tqdm(range(0, len_dataset, batch_size)):\n",
|
||||
"\n",
|
||||
" passage_negative_list = []\n",
|
||||
" for id in range(i, min(i + batch_size, len_dataset)):\n",
|
||||
" for passage in all_dataset[id]['passage_negative']:\n",
|
||||
" passage_negative_list.append(passage)\n",
|
||||
"\n",
|
||||
" passage_negative_embeddings = text_embedder.embed_texts(passage_negative_list)\n",
|
||||
"\n",
|
||||
" count = 0\n",
|
||||
" for id in range(i, min(i + batch_size, len_dataset)):\n",
|
||||
" for passage_id in range(len(all_dataset[id]['passage_negative'])):\n",
|
||||
" all_dataset_embeddings[id]['passage_negative_embeddings'].append(passage_negative_embeddings[count])\n",
|
||||
" count += 1\n",
|
||||
"\n",
|
||||
"print(\"getting random negative passages\")\n",
|
||||
"for id in tqdm(range(len_dataset)):\n",
|
||||
" question_embeddings = all_dataset_embeddings[id]['question_embedding']\n",
|
||||
" passage_positive_embeddings = all_dataset_embeddings[id]['passage_positive_embedding'][0]\n",
|
||||
"\n",
|
||||
" score_question_passage_positive = np.dot(question_embeddings, passage_positive_embeddings)\n",
|
||||
"\n",
|
||||
" while len(all_dataset[id]['passage_negative_random']) < RANDOM_NEGATIVE_COUNT:\n",
|
||||
" random_id = random.randint(0, len_dataset - 1)\n",
|
||||
" if random_id != id:\n",
|
||||
" all_passages_embedding = all_dataset_embeddings[random_id]['passage_negative_embeddings'] + all_dataset_embeddings[random_id]['passage_positive_embedding']\n",
|
||||
" all_passages = all_dataset[random_id]['passage_negative'] + all_dataset[random_id]['passage_positive']\n",
|
||||
" random_passage_id = random.randint(0, len(all_passages) - 1)\n",
|
||||
" random_passage_embeddings = all_passages_embedding[random_passage_id]\n",
|
||||
" score_question_random_passage = np.dot(question_embeddings, random_passage_embeddings)\n",
|
||||
"\n",
|
||||
" if score_question_random_passage < THRESHOLD_MULTIPLY * score_question_passage_positive:\n",
|
||||
" all_dataset[id]['passage_negative_random'].append(all_passages[random_passage_id])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "252b18c2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"100"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(all_dataset[:100])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "d958564d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"with open(\"./x.json\", 'w', encoding='utf-8') as f:\n",
|
||||
" json.dump(all_dataset[:100], f, ensure_ascii=False, indent=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b86d5fdf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
data_preprocess_notebook/data_loader_synthetic.ipynb (new file, 298 lines)
@ -0,0 +1,298 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"id": "a78759c8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1000212\n",
|
||||
"250666\n",
|
||||
"270642\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"dataset_synthetic_scores = []\n",
|
||||
"with open('/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/train.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" data = json.loads(line)\n",
|
||||
" dataset_synthetic_scores.append(data)\n",
|
||||
"\n",
|
||||
"dataset_synthetic_queries = {}\n",
|
||||
"with open('/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/queries.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" json_data = json.loads(line)\n",
|
||||
" dataset_synthetic_queries[json_data['_id']] = json_data\n",
|
||||
"\n",
|
||||
"dataset_synthetic_corpus = {}\n",
|
||||
"with open('/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/corpus.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" json_data = json.loads(line)\n",
|
||||
" dataset_synthetic_corpus[json_data['_id']] = json_data\n",
|
||||
"\n",
|
||||
"print(len(dataset_synthetic_scores))\n",
|
||||
"print(len(dataset_synthetic_queries))\n",
|
||||
"print(len(dataset_synthetic_corpus))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"id": "bbb2657f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"223423\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_dataset = {}\n",
|
||||
"for data_topic in dataset_synthetic_scores:\n",
|
||||
" \n",
|
||||
" query_id = data_topic['query-id']\n",
|
||||
" corpus_id = int(data_topic['corpus-id'])\n",
|
||||
" score = data_topic['score']\n",
|
||||
"\n",
|
||||
" if (query_id in dataset_synthetic_queries) and (corpus_id in dataset_synthetic_corpus):\n",
|
||||
" if score == \"1\":\n",
|
||||
" if query_id in all_dataset:\n",
|
||||
" all_dataset[query_id]['passgae_positive'].append(dataset_synthetic_corpus[corpus_id]['text'])\n",
|
||||
" else:\n",
|
||||
" all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'], \n",
|
||||
" 'passgae_positive': [dataset_synthetic_corpus[corpus_id]['text']], \n",
|
||||
" 'passgae_negative': [], \n",
|
||||
" 'passage_negative_random': []}\n",
|
||||
" else:\n",
|
||||
" if query_id in all_dataset:\n",
|
||||
" all_dataset[query_id]['passgae_negative'].append(dataset_synthetic_corpus[corpus_id]['text'])\n",
|
||||
" else:\n",
|
||||
" all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'], \n",
|
||||
" 'passgae_positive': [],\n",
|
||||
" 'passgae_negative': [dataset_synthetic_corpus[corpus_id]['text']],\n",
|
||||
" 'passage_negative_random': []}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"all_dataset = list(all_dataset.values())\n",
|
||||
"print(len(all_dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "74ef02a1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'question': 'چه کسانی باید کتاب لوازم نویسندگی را بخوانند؟',\n",
|
||||
" 'passgae_positive': ['این کتاب به ویژه برای علاقه مندان تازه کار به هنر داستان نویسی مفید است. افرادی که مشتاق یادگیری نمایشنامه، فیلمنامه، حکایت یا قصه هستند می توانند از این راهنمای کاربردی بهره ببرند.'],\n",
|
||||
" 'passgae_negative': ['این کتاب در دسته بندی پژوهش ادبی، مجموعه آموزش نویسندگی قرار می گیرد. همچنین این کتاب به عنوان یک نقشه راه برای افرادی که ایده ای را پرورش داده اند و قصد دارند آن را با قلمی رسا بیان کنند، پیشنهاد شده است.',\n",
|
||||
" 'این کتاب به ویژه برای افرادی که در حوزه ادبیات کودک فعالیت می\\u200cکنند یا قصد ورود به این حوزه را دارند، مفید و سازنده است. خواندن این کتاب می\\u200cتواند به نویسندگان کمک کند تا با بازار نویسندگی برای کودکان آشنا شوند و مهارت\\u200cهای لازم برای نوشتن آثار مناسب برای این گروه سنی را کسب کنند.',\n",
|
||||
" \"کتاب 'همه چیز درباره نویسندگی خلاق' نکات کلیدی و چالش\\u200cهای مختلفی را برای نویسندگان تازه\\u200cکار ارائه می\\u200cدهد. این نکات شامل تکنیک\\u200cهای فرّار برای غلبه بر خشک\\u200cطبعی در نویسندگی، منابع الهام\\u200cبخش، مثال\\u200cها و گزیده\\u200cهای مختلف است. همچنین، مصاحبه\\u200cهایی با نویسندگان موفق در این کتاب وجود دارد که می\\u200cتواند به خوانندگان انگیزه و الهام بیشتری برای نوشتن بدهد.\"],\n",
|
||||
" 'passage_negative_random': []}"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_dataset[71]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"id": "8e167b4b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"question:\n",
|
||||
"چگونه خنده میتواند به کاهش استرس کمک کند؟\n",
|
||||
"--------------------------------\n",
|
||||
"passgae_positive:\n",
|
||||
"خنده به عنوان یک واکنش طبیعی بدن، میتواند به کاهش سطح هورمونهای استرس مانند کورتیزول کمک کند. تحقیقی از دانشگاه کانزاس نشان داده است که خندیدن در شرایط استرسزا، ضربان قلب افراد را کاهش میدهد و به آنها احساس آرامش بیشتری میدهد. این اثرات مثبت به ویژه در خندههای اجتماعی مشهود است، که نشان میدهد حتی لبخند زدن نیز میتواند به کاهش استرس کمک کند.\n",
|
||||
"--------------------------------\n",
|
||||
"خندیدن به اشتباهات میتواند به عنوان یک مکانیزم مقابلهای عمل کند که به افراد کمک میکند تا با فشارهای روانی و استرسهای روزمره کنار بیایند. این عمل نه تنها به کاهش تنشهای عاطفی کمک میکند، بلکه میتواند به بهبود روابط اجتماعی نیز منجر شود. در واقع، افرادی که قادر به خندیدن به اشتباهات خود هستند، معمولاً احساس راحتی بیشتری در تعاملات اجتماعی دارند و میتوانند به راحتی با دیگران ارتباط برقرار کنند.\n",
|
||||
"{{\"result\": \"1\"}}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'result' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[48], line 57\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m match:\n\u001b[1;32m 56\u001b[0m result \u001b[38;5;241m=\u001b[39m match\u001b[38;5;241m.\u001b[39mgroup(\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mresult\u001b[49m)\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m--------------------------------\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'result' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import requests\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"load_dotenv()\n",
|
||||
"\n",
|
||||
"qwen = False\n",
|
||||
"if qwen:\n",
|
||||
" url = \"https://qwen3.chatllm.aiengines.ir/v1/chat/completions\"\n",
|
||||
" model = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
|
||||
" headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {os.getenv('LLM_AS_RERANKER_PASS')}\"}\n",
|
||||
"else:\n",
|
||||
" url = \"http://192.168.130.206:4001/v1/chat/completions\"\n",
|
||||
" model = \"google/gemma-3-27b-it\"\n",
|
||||
" headers = {\"Content-Type\": \"application/json\"}\n",
|
||||
"\n",
|
||||
"instruction = \"\"\"\n",
|
||||
"You are a helpful assistant that help me to find that the text is relevant to the question or not.\n",
|
||||
"You are given a question and a text.\n",
|
||||
"You must evaluate the text based on the question and return \"1\" if the text is relevant to the question and \"0\" if the text is not relevant to the question.\n",
|
||||
" \n",
|
||||
"be carefull, I have chosen the text randomly from my dataset so the text must answer the question independently.\n",
|
||||
"You must return the result in the following format:\n",
|
||||
"{{\"result\": \"1\" or \"0\"}}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"id = 7850\n",
|
||||
"\n",
|
||||
"print(\"question:\")\n",
|
||||
"print(all_dataset[id][\"question\"])\n",
|
||||
"print(\"--------------------------------\")\n",
|
||||
"print(\"passgae_positive:\")\n",
|
||||
"print(all_dataset[id][\"passgae_positive\"][0])\n",
|
||||
"print(\"--------------------------------\")\n",
|
||||
"for i in range(len(all_dataset[id][\"passgae_negative\"])):\n",
|
||||
" question, passgae_negative = all_dataset[id]['question'], all_dataset[id][\"passgae_negative\"][i]\n",
|
||||
" input_message = f\"\"\"{{\"question\": \"{question}\", \"text\": \"{passgae_negative}\"}}\"\"\"\n",
|
||||
" messages = [{\"role\": \"system\", \"content\": instruction}, {\"role\": \"user\", \"content\": input_message}]\n",
|
||||
"\n",
|
||||
" payload = {\n",
|
||||
" \"model\": model,\n",
|
||||
" \"messages\": messages,\n",
|
||||
" \"max_tokens\": 100\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" req = requests.post(url, headers=headers, json=payload)\n",
|
||||
" print(all_dataset[id][\"passgae_negative\"][i])\n",
|
||||
" print(req.json()['choices'][0]['message']['content'])\n",
|
||||
" out = req.json()['choices'][0]['message']['content']\n",
|
||||
" \n",
|
||||
" match = re.search(r'\"result\":\\s*([\\d\\.]+)', out)\n",
|
||||
"\n",
|
||||
" if match:\n",
|
||||
" result = match.group(1)\n",
|
||||
" print(result)\n",
|
||||
" print(\"--------------------------------\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"id": "24b13c5b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'{{\"result\": \"1\"}}'"
|
||||
]
|
||||
},
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"id": "30586c26",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"out = '{\"result\": \"1\"}'\n",
|
||||
"match = re.search(r'\"result\":\\s*\"?([\\d\\.]+)\"?', out)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"result = match.group(1)\n",
|
||||
"\n",
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"id": "9fb33634",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"r = [\"\"]\n",
|
||||
"if not r:\n",
|
||||
" print(\"empty\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4917b3a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
requirements.txt (new file, 2 lines)
@ -0,0 +1,2 @@
python-dotenv==1.1.1
hazm