357 lines
13 KiB
Plaintext
357 lines
13 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "a78759c8",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"from datasets import load_dataset\n",
"\n",
"# Hugging Face Hub repository id of the dataset used throughout this notebook.\n",
"PQUAD_REPO = \"Gholamreza/pquad\"\n",
"\n",
"dataset = load_dataset(PQUAD_REPO)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "c91f659a",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"48273\n",
|
||
"63994\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Keep only the training rows that carry at least one answer text; the row's\n",
"# context becomes the single positive passage and the negative lists start empty.\n",
"train_split = dataset[\"train\"]\n",
"all_dataset = [\n",
"    {\n",
"        'question': row['question'],\n",
"        'passage_positive': [row['context']],\n",
"        'passage_negative': [],\n",
"        'passage_negative_random': [],\n",
"    }\n",
"    for row in train_split\n",
"    if row[\"answers\"][\"text\"]\n",
"]\n",
"\n",
"print(len(all_dataset))\n",
"print(len(train_split))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "d66809ce",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"{'question': 'سام میرزا در چه تاریخی توسط نادرشاه دستگیر شد؟', 'passage_positive': ['این ضربت سخت خیال نادر را پریشان کرد و رضا قلی میرزا را که در رکاب بود در طهران گذاشت و خود به داغستان رفت در این سفر اگرچه بعضی از رؤسای طوایف لزکی از در اطاعت درآمدند لیکن غالب سکنه داغستان به قلل جبال پرارتفاع پناه گرفتند و از هر طرف به تعرّض اردوی نادر دست زدند و لطمات بسیار به ایشان وارد آوردند حتّی موقعی به خیمه خود نادر نیز تعرّض رساندند. در رمضان ۱۱۵۴ موقعیکه نادر هنوز در داغستان بود غلامی را که مرتکب انداختن تیر در جنگل سوادکوه شده بود بخدمت او آوردند. نادر او را کور کرد. شخصی بنام سام میرزا که به ادّعای فرزندی شاه سلطان حسین در آذربایجان به سلطنت طلبی برخاسته و محمّد خان پسر سرخای خان لزگی و خوانین دربند و داغستان را با خود همدست نموده بود. نادر توسّط نصر اللّه میرزا و چند تن از سرداران خود انقلاب این حدود را بالاخره خواباند و سام میرزا در ذی\\u200cالقعده ۱۱۵۶ دستگیر گردید.'], 'passage_negative': [], 'passage_negative_random': []}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Spot-check one converted item.\n",
"sample_item = all_dataset[1241]\n",
"print(sample_item)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "9a566e69",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'id': 123802.0,\n",
|
||
" 'title': 'قرن نوزدهم',\n",
|
||
" 'context': 'جنگ داخلی ایالات متحده از سال ۱۸۶۱ تا ۱۸۶۵ طول کشید. در طول جنگ آبراهام لینکلن رئیس\\u200cجمهور بود. امروزه او را به عنوان یکی از بزرگ\\u200cترین رهبران جهان غرب تلقی می\\u200cکنند. در همین زمان، آمدن نیروی بخار در کنار انقلاب صنعتی رو به رشد، گسترش شدیدی در کارخانجات صنعتی ایجاد کرد. در سال ۱۸۷۸ توماس ادیسون لامپ برق جدید خود را به نمایش گذاشت و در طول یک دهه سیستم توزیع برق بزرگی را در تمام کشور راه\\u200cاندازی کرد. تأثیر اقتصادی به تدریج به خارج از ایالات متحده و به سمت اقیانوس آرام و آمریکای لاتین گسترش پیدا کرد.',\n",
|
||
" 'question': 'علت گسترش شدید کارخانجات صنعتی چه بود؟',\n",
|
||
" 'answers': {'text': ['آمدن نیروی بخار در کنار انقلاب صنعتی رو به رشد'],\n",
|
||
" 'answer_start': [180]}}"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
"# Spot-check one raw row of the training split.\n",
"raw_train_rows = dataset[\"train\"]\n",
"raw_train_rows[1251]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"id": "08a2e2d3",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from hazm import Normalizer\n",
|
||
"import requests\n",
|
||
"import numpy as np\n",
|
||
"from dotenv import load_dotenv\n",
|
||
"import os\n",
|
||
"\n",
|
||
"load_dotenv()\n",
|
||
"\n",
|
||
"\n",
|
"class TextEmbedder:\n",
"    \"\"\"HTTP client that turns texts into embedding vectors.\n",
"\n",
"    Each text is normalized with hazm's Normalizer and posted to an embeddings\n",
"    endpoint as a JSON body with `model` and `input`; the vectors are read from\n",
"    `data[i][\"embedding\"]` of the JSON response.\n",
"    \"\"\"\n",
"\n",
"    # Endpoint this notebook has always used; kept as the default for compatibility.\n",
"    DEFAULT_API_URL = \"http://78.38.161.78:3094/v1/embeddings\"\n",
"\n",
"    def __init__(self, model_name=\"BAAI/bge-m3\", api_url=DEFAULT_API_URL, timeout=60.0):\n",
"        \"\"\"\n",
"        Args:\n",
"            model_name: value sent as the `model` field of every request.\n",
"            api_url: URL of the embeddings endpoint.\n",
"            timeout: seconds passed to `requests.post`; without it a stalled\n",
"                server would block the notebook indefinitely.\n",
"        \"\"\"\n",
"        self.model_name = model_name\n",
"        self.api_url = api_url\n",
"        self.timeout = timeout\n",
"        # The bearer token comes from the EMBEDDING_PASS environment variable.\n",
"        # If it is unset the header reads \"Bearer None\"; the resulting HTTP error\n",
"        # is surfaced by raise_for_status() in embed_texts.\n",
"        self.headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {os.getenv('EMBEDDING_PASS')}\"}\n",
"        self.normalizer = Normalizer()\n",
"\n",
"    def preprocess_embedder(self, text: str) -> str:\n",
"        \"\"\"Replace every newline with a period, then apply the hazm normalizer.\"\"\"\n",
"        text = text.replace(\"\\n\", \".\")\n",
"        return self.normalizer.normalize(text)\n",
"\n",
"    def embed_texts(self, texts: list[str]) -> list[np.ndarray]:\n",
"        \"\"\"\n",
"        Embed texts using the model.\n",
"\n",
"        Args:\n",
"            texts: texts to embed; any empty sequence (or None) short-circuits.\n",
"\n",
"        Returns:\n",
"            One numpy array per input text, in the order the server returned\n",
"            them. An empty input returns [] without sending a request.\n",
"\n",
"        Raises:\n",
"            requests.HTTPError: the server answered with a 4xx/5xx status.\n",
"            ValueError: the response does not hold one embedding per input.\n",
"        \"\"\"\n",
"        if not texts:\n",
"            return []\n",
"\n",
"        payload = {\n",
"            \"model\": self.model_name,\n",
"            \"input\": [self.preprocess_embedder(text) for text in texts],\n",
"        }\n",
"        response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=self.timeout)\n",
"        # Fail with the real HTTP error instead of a KeyError on \"data\" below.\n",
"        response.raise_for_status()\n",
"\n",
"        embeddings = [np.array(item[\"embedding\"]) for item in response.json()[\"data\"]]\n",
"        # Callers pair vectors with inputs by position, so a short response must not pass silently.\n",
"        if len(embeddings) != len(payload[\"input\"]):\n",
"            raise ValueError(f\"expected {len(payload['input'])} embeddings, got {len(embeddings)}\")\n",
"\n",
"        return embeddings"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "978e4ac3",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"calculate question embeddings\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 5/5 [00:00<00:00, 6.35it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"calculate passage positive embeddings\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 5/5 [00:01<00:00, 4.47it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"calculate passage negative embeddings\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 5/5 [00:00<00:00, 21140.65it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"getting random negative passages\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 220/220 [00:00<00:00, 3699.93it/s]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"import random\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"\n",
"THRESHOLD_MULTIPLY = 0.9\n",
"RANDOM_NEGATIVE_COUNT = 50\n",
"batch_size = 50\n",
"SUBSET_START = 1500\n",
"SUBSET_SIZE = 220\n",
"RANDOM_SEED = 42\n",
"# Cap on sampling attempts per question so the rejection loop below always ends,\n",
"# even when fewer than RANDOM_NEGATIVE_COUNT passages pass the threshold.\n",
"MAX_DRAWS_PER_QUESTION = 100 * RANDOM_NEGATIVE_COUNT\n",
"\n",
"# Seed the sampler so a fresh run reproduces the same negatives.\n",
"random.seed(RANDOM_SEED)\n",
"\n",
"# This cell rebinds `all_dataset` to a slice of itself. Running it a second time\n",
"# would slice the already-sliced list and silently continue with zero items, so\n",
"# an empty slice is treated as an error.\n",
"dataset_subset = all_dataset[SUBSET_START:SUBSET_START + SUBSET_SIZE]\n",
"if not dataset_subset:\n",
"    raise RuntimeError(\n",
"        \"the selected slice of all_dataset is empty; re-run the cell that builds \"\n",
"        \"all_dataset before running this cell again\"\n",
"    )\n",
"all_dataset = dataset_subset\n",
"\n",
"text_embedder = TextEmbedder()\n",
"len_dataset = len(all_dataset)\n",
"all_dataset_embeddings = [\n",
"    {'question_embedding': \"\", 'passage_positive_embedding': [], 'passage_negative_embeddings': []}\n",
"    for _ in range(len_dataset)\n",
"]\n",
"\n",
"\n",
"def embed_in_batches(texts):\n",
"    \"\"\"Embed `texts` in chunks of `batch_size`; returns one vector per text, in input order.\"\"\"\n",
"    vectors = []\n",
"    for start in tqdm(range(0, len(texts), batch_size)):\n",
"        chunk = texts[start:start + batch_size]\n",
"        chunk_vectors = text_embedder.embed_texts(chunk)\n",
"        # Vectors are matched to texts by position, so the counts must agree.\n",
"        if len(chunk_vectors) != len(chunk):\n",
"            raise ValueError(f\"expected {len(chunk)} embeddings, got {len(chunk_vectors)}\")\n",
"        vectors.extend(chunk_vectors)\n",
"    return vectors\n",
"\n",
"\n",
"def embed_passage_field(text_key, embedding_key):\n",
"    \"\"\"Embed every passage stored under `text_key` and append each vector to the\n",
"    `embedding_key` list of the row that owns the passage.\"\"\"\n",
"    owner_rows = []\n",
"    passages = []\n",
"    for row_idx, row in enumerate(all_dataset):\n",
"        for passage in row[text_key]:\n",
"            owner_rows.append(row_idx)\n",
"            passages.append(passage)\n",
"\n",
"    for row_idx, vector in zip(owner_rows, embed_in_batches(passages)):\n",
"        all_dataset_embeddings[row_idx][embedding_key].append(vector)\n",
"\n",
"\n",
"print(\"calculate question embeddings\")\n",
"question_vectors = embed_in_batches([row['question'] for row in all_dataset])\n",
"for row_idx, vector in enumerate(question_vectors):\n",
"    all_dataset_embeddings[row_idx]['question_embedding'] = vector\n",
"\n",
"print(\"calculate passage positive embeddings\")\n",
"embed_passage_field('passage_positive', 'passage_positive_embedding')\n",
"\n",
"print(\"calculate passage negative embeddings\")\n",
"embed_passage_field('passage_negative', 'passage_negative_embeddings')\n",
"\n",
"print(\"getting random negative passages\")\n",
"for row_idx in tqdm(range(len_dataset)):\n",
"    question_vector = all_dataset_embeddings[row_idx]['question_embedding']\n",
"    positive_vector = all_dataset_embeddings[row_idx]['passage_positive_embedding'][0]\n",
"\n",
"    # The dot product is used as the similarity score.\n",
"    # NOTE(review): this is only a cosine similarity if the endpoint returns unit-length vectors - confirm.\n",
"    positive_score = np.dot(question_vector, positive_vector)\n",
"\n",
"    negatives = all_dataset[row_idx]['passage_negative_random']\n",
"    # De-duplicate by passage text so a passage drawn twice is stored once.\n",
"    seen_passages = set(negatives)\n",
"    draws = 0\n",
"\n",
"    while len(negatives) < RANDOM_NEGATIVE_COUNT and draws < MAX_DRAWS_PER_QUESTION:\n",
"        draws += 1\n",
"        other_idx = random.randint(0, len_dataset - 1)\n",
"        if other_idx == row_idx:\n",
"            continue\n",
"\n",
"        candidate_vectors = all_dataset_embeddings[other_idx]['passage_negative_embeddings'] + all_dataset_embeddings[other_idx]['passage_positive_embedding']\n",
"        candidate_texts = all_dataset[other_idx]['passage_negative'] + all_dataset[other_idx]['passage_positive']\n",
"        if not candidate_texts:\n",
"            # Nothing to draw from this row (random.randint(0, -1) would raise).\n",
"            continue\n",
"\n",
"        pick = random.randint(0, len(candidate_texts) - 1)\n",
"        candidate = candidate_texts[pick]\n",
"        if candidate in seen_passages:\n",
"            continue\n",
"\n",
"        # Accept only passages that score clearly below the true positive.\n",
"        candidate_score = np.dot(question_vector, candidate_vectors[pick])\n",
"        if candidate_score < THRESHOLD_MULTIPLY * positive_score:\n",
"            negatives.append(candidate)\n",
"            seen_passages.add(candidate)\n",
"\n",
"    if len(negatives) < RANDOM_NEGATIVE_COUNT:\n",
"        tqdm.write(f\"warning: only {len(negatives)} of {RANDOM_NEGATIVE_COUNT} random negatives found for item {row_idx}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "252b18c2",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"100"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
"# Number of items the export cell will write (at most the first 100).\n",
"export_preview = all_dataset[:100]\n",
"len(export_preview)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "d958564d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"import json\n",
"\n",
"# Serialize the first 100 items as indented, non-ASCII-escaped JSON and write\n",
"# the text to ./x.json as UTF-8.\n",
"export_rows = all_dataset[:100]\n",
"export_text = json.dumps(export_rows, ensure_ascii=False, indent=4)\n",
"\n",
"with open(\"./x.json\", 'w', encoding='utf-8') as out_file:\n",
"    out_file.write(export_text)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "b86d5fdf",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.12"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|