361 lines
14 KiB
Plaintext
361 lines
14 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "a78759c8",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/datasets/load.py:1461: FutureWarning: The repository for Gholamreza/pquad contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Gholamreza/pquad\n",
|
||
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
|
||
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
|
||
" warnings.warn(\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
from datasets import load_dataset

# PQuAD: Persian extractive-QA dataset. The repository ships a custom loading
# script, so `trust_remote_code=True` is required — the FutureWarning captured
# in this cell's stderr says it becomes mandatory in the next major `datasets`
# release; passing it now silences the warning and future-proofs the load.
dataset = load_dataset("Gholamreza/pquad", trust_remote_code=True)
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "c91f659a",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"48273\n",
|
||
"63994\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
# Keep only answerable examples (non-empty gold answer text). Each record
# pairs the question with its gold context as the single positive passage;
# the negative buckets start empty and are mined in a later cell.
all_dataset = [
    {
        "question": example["question"],
        "passage_positive": [example["context"]],
        "passage_negative": [],
        "passage_negative_random": [],
    }
    for example in dataset["train"]
    if len(example["answers"]["text"]) > 0
]

print(len(all_dataset))
print(len(dataset["train"]))
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "d66809ce",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"{'question': 'جنگ جهانی اول در چه تاریخی پایان یافت؟', 'passgae_positive': [], 'passgae_negative': ['در سال ۱۸۷۱ امپراتوری آلمان با اتحاد پروس و کنفدراسیون جرمن شمالی توسط اتو ون بیسمارک به وجود آمد. این کشور قدرتمند تا سال ۱۹۱۸ ادامه یافت و با عنوان رایش دوم مشهور شد. بیسمارک توانست استان\\u200cهای جدید زیادی را طی جنگ\\u200cهای مبتکرانهٔ کوتاه و دیپلماتیک به دست آورد. او با اتریش هم پیمان شد تا دانمارک را شکست دهد و ناحیهٔ شلزویگ-هولشتاین را تصرف کند. او جنگ اتریش و پروس (آسترو-پروسیان) را آغاز کرد و پیروز شد اما اینکار فقط برای این بود که ایتالیا طرف آلمان را بگیرد. سپس پروس وارد جنگ فرانسه و پروس (فرانکو-پروسین) (۷۱-۱۸۷۰) شد و توانست شکست کاملی به فرانسه وارد سازد. ویلهلم اول به عنوان آخرین توهین به فرانسوی\\u200cها در کاخ ورسای در قلب فرانسه به عنوان امپراتور آلمان سوگند خورد. امپراتوری آلمان تا پایان جنگ جهانی اول یعنی زمانی که فرانسه توانست در پیمان ورسای تلافی بکند در اوج خود بود.']}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(all_dataset[1240])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "9a566e69",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'id': 123802.0,\n",
|
||
" 'title': 'قرن نوزدهم',\n",
|
||
" 'context': 'جنگ داخلی ایالات متحده از سال ۱۸۶۱ تا ۱۸۶۵ طول کشید. در طول جنگ آبراهام لینکلن رئیس\\u200cجمهور بود. امروزه او را به عنوان یکی از بزرگ\\u200cترین رهبران جهان غرب تلقی می\\u200cکنند. در همین زمان، آمدن نیروی بخار در کنار انقلاب صنعتی رو به رشد، گسترش شدیدی در کارخانجات صنعتی ایجاد کرد. در سال ۱۸۷۸ توماس ادیسون لامپ برق جدید خود را به نمایش گذاشت و در طول یک دهه سیستم توزیع برق بزرگی را در تمام کشور راه\\u200cاندازی کرد. تأثیر اقتصادی به تدریج به خارج از ایالات متحده و به سمت اقیانوس آرام و آمریکای لاتین گسترش پیدا کرد.',\n",
|
||
" 'question': 'علت گسترش شدید کارخانجات صنعتی چه بود؟',\n",
|
||
" 'answers': {'text': ['آمدن نیروی بخار در کنار انقلاب صنعتی رو به رشد'],\n",
|
||
" 'answer_start': [180]}}"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"dataset[\"train\"][1251]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"id": "08a2e2d3",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
from hazm import Normalizer
import requests
import numpy as np
from dotenv import load_dotenv
import os

# Load EMBEDDING_PASS (the embedding API bearer token, read via os.getenv in
# TextEmbedder.__init__) from a local .env file so the credential is never
# hardcoded in the notebook.
load_dotenv()
|
||
class TextEmbedder:
    """Normalize Persian text (hazm) and embed it via a remote
    OpenAI-style ``/v1/embeddings`` endpoint.

    Reads the bearer token from the ``EMBEDDING_PASS`` environment variable
    (loaded from ``.env`` in the setup cell).
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        api_url: str = "http://78.38.161.78:3094/v1/embeddings",
        timeout: float = 60.0,
    ):
        # api_url and timeout are parameterized (previously the URL was
        # hardcoded inside embed_texts and no timeout was set, so a dead
        # endpoint would hang the notebook indefinitely). Defaults preserve
        # the original behavior for existing callers.
        self.model_name = model_name
        self.api_url = api_url
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('EMBEDDING_PASS')}",
        }
        self.normalizer = Normalizer()

    def preprocess_embedder(self, text: str) -> str:
        """Prepare one string for embedding: newlines -> '.', then hazm normalization."""
        text = text.replace("\n", ".")
        return self.normalizer.normalize(text)

    def embed_texts(self, texts: list[str]) -> list[np.ndarray]:
        """
        Embed texts using the model.

        Returns one np.ndarray per input text, in input order; an empty
        input list returns [] without any network call.

        Raises requests.HTTPError on a non-2xx response (previously a bad
        response surfaced as an opaque KeyError on the JSON body) and
        requests.Timeout if the endpoint does not answer within `timeout`.
        """
        if not texts:
            return []

        texts = [self.preprocess_embedder(text) for text in texts]
        payload = {
            "model": self.model_name,
            "input": texts,
        }
        response = requests.post(
            self.api_url, headers=self.headers, json=payload, timeout=self.timeout
        )
        response.raise_for_status()  # fail loudly instead of KeyError below
        return [np.array(item["embedding"]) for item in response.json()["data"]]
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "978e4ac3",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"calculate question embeddings\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 5/5 [00:00<00:00, 6.35it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"calculate passage positive embeddings\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 5/5 [00:01<00:00, 4.47it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"calculate passage negative embeddings\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 5/5 [00:00<00:00, 21140.65it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"getting random negative passages\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 220/220 [00:00<00:00, 3699.93it/s]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
import random
from tqdm import tqdm
import numpy as np

THRESHOLD_MULTIPLY = 0.9        # a negative must score < 0.9 * the positive's score
RANDOM_NEGATIVE_COUNT = 50      # random negatives mined per question
MAX_SAMPLING_ATTEMPTS = 10_000  # guard: the sampler below could otherwise spin forever
batch_size = 50

# NOTE(review): destructively slices the working set to 220 rows — presumably
# a smoke-test run; remove the slice (and re-run from the top) for a full run.
all_dataset = all_dataset[1500:1500 + 220]

text_embedder = TextEmbedder()
len_dataset = len(all_dataset)
all_dataset_embeddings = [
    {
        "question_embedding": "",
        "passage_positive_embedding": [],
        "passage_negative_embeddings": [],
    }
    for _ in range(len_dataset)
]


def _embed_field(field, target_key, is_list):
    """Embed `field` of every record in batches of `batch_size`, writing the
    vectors into all_dataset_embeddings under `target_key`.

    is_list=False: the field is a single string per record (the question) and
                   target_key holds one vector.
    is_list=True:  the field is a list of passages and target_key is a list
                   that receives one vector per passage, in order.

    (Replaces three copy-pasted batching loops from the original cell.)
    """
    for start in tqdm(range(0, len_dataset, batch_size)):
        stop = min(start + batch_size, len_dataset)

        if is_list:
            batch = [p for idx in range(start, stop) for p in all_dataset[idx][field]]
        else:
            batch = [all_dataset[idx][field] for idx in range(start, stop)]

        embeddings = text_embedder.embed_texts(batch)

        cursor = 0
        for idx in range(start, stop):
            if is_list:
                for _ in all_dataset[idx][field]:
                    all_dataset_embeddings[idx][target_key].append(embeddings[cursor])
                    cursor += 1
            else:
                all_dataset_embeddings[idx][target_key] = embeddings[cursor]
                cursor += 1


print("calculate question embeddings")
_embed_field("question", "question_embedding", is_list=False)

print("calculate passage positive embeddings")
_embed_field("passage_positive", "passage_positive_embedding", is_list=True)

print("calculate passage negative embeddings")
_embed_field("passage_negative", "passage_negative_embeddings", is_list=True)

print("getting random negative passages")
for idx in tqdm(range(len_dataset)):
    question_embedding = all_dataset_embeddings[idx]["question_embedding"]
    positive_embeddings = all_dataset_embeddings[idx]["passage_positive_embedding"]
    if not positive_embeddings:
        # A record without a positive passage can't be scored against —
        # skip it instead of raising IndexError on [0] as before.
        continue
    score_positive = np.dot(question_embedding, positive_embeddings[0])

    # Rejection-sample passages from other records; keep only those clearly
    # less similar to the question than the gold passage. The attempt cap
    # prevents an infinite loop when too few candidates clear the threshold.
    attempts = 0
    while (len(all_dataset[idx]["passage_negative_random"]) < RANDOM_NEGATIVE_COUNT
           and attempts < MAX_SAMPLING_ATTEMPTS):
        attempts += 1
        other = random.randint(0, len_dataset - 1)
        if other == idx:
            continue
        candidate_embeddings = (
            all_dataset_embeddings[other]["passage_negative_embeddings"]
            + all_dataset_embeddings[other]["passage_positive_embedding"]
        )
        candidates = (
            all_dataset[other]["passage_negative"]
            + all_dataset[other]["passage_positive"]
        )
        pick = random.randint(0, len(candidates) - 1)
        score_candidate = np.dot(question_embedding, candidate_embeddings[pick])
        if score_candidate < THRESHOLD_MULTIPLY * score_positive:
            all_dataset[idx]["passage_negative_random"].append(candidates[pick])
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "252b18c2",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"100"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"len(all_dataset[:100])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "d958564d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
import json

# Export the first 100 mined records for manual inspection.
# ensure_ascii=False keeps the Persian text human-readable in the file.
serialized = json.dumps(all_dataset[:100], ensure_ascii=False, indent=4)
with open("./x.json", 'w', encoding='utf-8') as f:
    f.write(serialized)
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "b86d5fdf",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.12"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|