193 lines
4.5 KiB
Plaintext
193 lines
4.5 KiB
Plaintext
{
|
|
"cells": [
|
|
{
  "cell_type": "code",
  "execution_count": 1,
  "id": "9dbad513",
  "metadata": {},
  "outputs": [],
  "source": [
    "# Load the train split of the financial QA dataset from the Hugging Face Hub.\n",
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(path=\"virattt/financial-qa-10K\", split=\"train\")"
  ]
},
|
|
{
  "cell_type": "code",
  "execution_count": 2,
  "id": "7330f385",
  "metadata": {},
  "outputs": [
    {
      "data": {
        "text/plain": [
          "{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n",
          " 'pos': 'Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.',\n",
          " 'id': '0'}"
        ]
      },
      "execution_count": 2,
      "metadata": {},
      "output_type": "execute_result"
    }
  ],
  "source": [
    "# Keep only the question/context pair, rename it to query/pos, and give every row\n",
    "# a string id equal to its position.\n",
    "# NOTE: reassigns `ds`, so re-running this cell alone fails (\"question\" is gone) -\n",
    "# re-run from the loading cell instead.\n",
    "ds = ds.select_columns(column_names=[\"question\", \"context\"])\n",
    "ds = ds.rename_columns({\"question\": \"query\", \"context\": \"pos\"})\n",
    "ds = ds.add_column(\"id\", list(map(str, range(len(ds)))))\n",
    "ds[0]"
  ]
},
|
|
{
  "cell_type": "code",
  "execution_count": null,
  "id": "5ba361dd",
  "metadata": {},
  "outputs": [],
  "source": [
    "import numpy as np\n",
    "\n",
    "SEED = 520\n",
    "neg_num = 10  # negative passages sampled per query\n",
    "\n",
    "np.random.seed(SEED)  # legacy global seed, kept in case later code draws from np.random\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "\n",
    "def str_to_lst(data):\n",
    "    \"\"\"Wrap a row's single `pos` passage in a one-element list (applied via `Dataset.map`).\"\"\"\n",
    "    data[\"pos\"] = [data[\"pos\"]]\n",
    "    return data\n",
    "\n",
    "\n",
    "# Sample negative texts without replacement from the *distinct* passages and drop the\n",
    "# query's own positive text, so no row gets a repeated negative or a \"negative\" that is\n",
    "# identical to its positive (possible when one context backs several questions).\n",
    "# NOTE: reassigns `ds`; re-run from the loading cell rather than re-running this cell alone.\n",
    "passages = list(ds[\"pos\"])  # read the column once instead of one row lookup per sample\n",
    "pool = list(dict.fromkeys(passages))  # distinct passages in first-seen order\n",
    "if len(pool) <= neg_num:\n",
    "    raise ValueError(f\"need more than {neg_num} distinct passages to sample negatives, got {len(pool)}\")\n",
    "\n",
    "new_col = []\n",
    "for pos_text in passages:\n",
    "    # neg_num + 1 distinct picks leave at least neg_num after removing the positive\n",
    "    picks = rng.choice(len(pool), size=neg_num + 1, replace=False)\n",
    "    new_col.append([pool[j] for j in picks if pool[j] != pos_text][:neg_num])\n",
    "\n",
    "assert all(len(negs) == neg_num for negs in new_col)\n",
    "assert all(pos_text not in negs for pos_text, negs in zip(passages, new_col))\n",
    "ds = ds.add_column(\"neg\", new_col)\n",
    "\n",
    "# wrap each row's 'pos' string in a list\n",
    "ds = ds.map(str_to_lst)"
  ]
},
|
|
{
  "cell_type": "code",
  "execution_count": 4,
  "id": "bf3241ca",
  "metadata": {},
  "outputs": [],
  "source": [
    "# Attach the same query instruction to every row.\n",
    "# NOTE(review): presumably the retrieval-query prefix the target embedding model expects - confirm against its model card.\n",
    "instruction = \"Represent this sentence for searching relevant passages: \"\n",
    "prompts = [instruction for _ in range(len(ds))]\n",
    "ds = ds.add_column(\"prompt\", prompts)"
  ]
},
|
|
{
  "cell_type": "code",
  "execution_count": 12,
  "id": "a35c1466",
  "metadata": {},
  "outputs": [],
  "source": [
    "# Hold out 2% of the rows for evaluation; the fixed seed makes the shuffle reproducible.\n",
    "split = ds.train_test_split(test_size=0.02, shuffle=True, seed=520)\n",
    "train, test = split[\"train\"], split[\"test\"]"
  ]
},
|
|
{
  "cell_type": "code",
  "execution_count": 13,
  "id": "aec6787d",
  "metadata": {},
  "outputs": [
    {
      "data": {
        "text/plain": [
          "140"
        ]
      },
      "execution_count": 13,
      "metadata": {},
      "output_type": "execute_result"
    }
  ],
  "source": [
    "# number of held-out rows\n",
    "test.num_rows"
  ]
},
|
|
{
  "cell_type": "code",
  "execution_count": 14,
  "id": "c5cc42ed",
  "metadata": {},
  "outputs": [
    {
      "name": "stderr",
      "output_type": "stream",
      "text": [
        "Creating json from Arrow format: 0%| | 0/7 [00:00<?, ?ba/s]"
      ]
    },
    {
      "name": "stderr",
      "output_type": "stream",
      "text": [
        "Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 21.58ba/s]\n",
        "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 148.87ba/s]\n"
      ]
    },
    {
      "data": {
        "text/plain": [
          "364936"
        ]
      },
      "execution_count": 14,
      "metadata": {},
      "output_type": "execute_result"
    }
  ],
  "source": [
    "# Write both splits to disk; `to_json` returns the number of bytes written and the\n",
    "# test split's count is shown as the cell output.\n",
    "# NOTE(review): `Dataset.to_json` defaults to JSON Lines despite the .json suffix - confirm the trainer expects that.\n",
    "train_bytes = train.to_json(\"training.json\")\n",
    "test_bytes = test.to_json(\"test.json\")\n",
    "test_bytes"
  ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "536227f7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|