179 lines
4.0 KiB
Plaintext
179 lines
4.0 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "16798408",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"import json\n",
"from collections import Counter\n",
"from pathlib import Path\n",
"\n",
"# Location of the training data; change it here rather than in every cell.\n",
"DATA_PATH = Path(\"/home/firouzi/embedding_model/data/train_100.json\")\n",
"\n",
"with open(DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
"    all_dataset = json.load(f)\n",
"\n",
"# Number of negatives (hard + random) available for each example.\n",
"data_count = [\n",
"    len(data[\"passage_negative\"]) + len(data[\"passage_negative_random\"])\n",
"    for data in all_dataset\n",
"]\n",
"\n",
"# Histogram: negatives per example -> number of examples with that many.\n",
"# dict(...) yields a plain dict whose keys are in first-seen order.\n",
"counts = dict(Counter(data_count))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "a0eb428f",
|
|
"metadata": {},
|
|
"outputs": [
{
"data": {
"text/plain": [
"{5: 1, 6: 11, 7: 20, 8: 22, 9: 46}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dict(sorted(counts.items()))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "ca0412d2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"from datasets import Dataset\n",
"\n",
"# Every example is expanded to exactly NUM_NEGATIVES negative columns:\n",
"# examples with fewer negatives are padded by repeating their first negative,\n",
"# examples with more are truncated.\n",
"# NOTE(review): the count histogram shows up to 9 negatives per example, so\n",
"# several negatives are discarded here; confirm that is intended.\n",
"NUM_NEGATIVES = 5\n",
"\n",
"anchors = []\n",
"positives = []\n",
"negative_columns = [[] for _ in range(NUM_NEGATIVES)]\n",
"\n",
"# Reuses `all_dataset` from the loading cell instead of re-reading the file.\n",
"for data in all_dataset:\n",
"    all_negatives = data[\"passage_negative\"] + data[\"passage_negative_random\"]\n",
"    if not all_negatives:\n",
"        # Padding needs at least one negative to repeat.\n",
"        raise ValueError(f\"example has no negatives: {data['question']!r}\")\n",
"    padding = [all_negatives[0]] * (NUM_NEGATIVES - len(all_negatives))\n",
"    anchors.append(data[\"question\"])\n",
"    positives.append(data[\"passage_positive\"])\n",
"    # zip stops after NUM_NEGATIVES columns, truncating any extra negatives.\n",
"    for column, negative in zip(negative_columns, all_negatives + padding):\n",
"        column.append(negative)\n",
"\n",
"dataset = Dataset.from_dict({\n",
"    \"anchor\": anchors,\n",
"    \"positive\": positives,\n",
"    **{\n",
"        f\"negative_{k}\": column\n",
"        for k, column in enumerate(negative_columns, start=1)\n",
"    },\n",
"})\n"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "cc963d18",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Hold out 5% of the examples for evaluation; the fixed seed keeps the split reproducible.\n",
"dataset_split = dataset.train_test_split(test_size=0.05, seed=42)\n",
"train_dataset, test_dataset = dataset_split[\"train\"], dataset_split[\"test\"]"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "593f7ce4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"95"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"train_dataset.num_rows"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "f0443056",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"5"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"test_dataset.num_rows"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "377f53ba",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|