{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0f1e2d3c",
   "metadata": {},
   "source": [
    "# Fine-tuning data from `virattt/financial-qa-10K`\n",
    "\n",
    "Builds a retrieval fine-tuning set in which every example has a `query`, a list of positive passages `pos`, a list of sampled negative passages `neg`, an instruction `prompt`, and a string `id`. Each question's own context is its positive; `NEG_NUM` contexts drawn from other rows are its negatives. The examples are split into train/test (`TEST_SIZE` held out) and the train split is written to `OUTPUT_PATH` with `Dataset.to_json` (JSON Lines)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c8e1b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "\n",
    "SEED = 520\n",
    "NEG_NUM = 10  # negatives sampled per query\n",
    "TEST_SIZE = 0.1  # fraction of examples held out for the test split\n",
    "INSTRUCTION = \"Represent this sentence for searching relevant passages: \"\n",
    "OUTPUT_PATH = \"training.json\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d2a6e91",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Keep only the question/context pair, rename it to the `query`/`pos` fields of the fine-tuning format, and attach a string row id."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dbad513",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_ds = load_dataset(\"virattt/financial-qa-10K\", split=\"train\")\n",
    "raw_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7330f385",
   "metadata": {},
   "outputs": [],
   "source": [
    "qa_ds = (\n",
    "    raw_ds\n",
    "    .select_columns([\"question\", \"context\"])\n",
    "    .rename_columns({\"question\": \"query\", \"context\": \"pos\"})\n",
    "    .add_column(\"id\", [str(i) for i in range(len(raw_ds))])\n",
    ")\n",
    "qa_ds[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b7c8d9e",
   "metadata": {},
   "source": [
    "## Sample negatives\n",
    "\n",
    "Negatives are random passages from *other* rows, drawn without replacement so a row never gets the same negative twice. A draw is rejected if it contains text identical to the row's positive, so a context shared by several questions is never used as a (false) negative for any of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ba361dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_negatives(passages, neg_num, rng):\n",
    "    \"\"\"Draw `neg_num` negative passages for every row of `passages`.\n",
    "\n",
    "    Negatives for row i are sampled without replacement from the other rows.\n",
    "    A draw is rejected and repeated if it contains text identical to\n",
    "    `passages[i]`, so a context shared by several questions is never used\n",
    "    as a negative for any of them.\n",
    "\n",
    "    Args:\n",
    "        passages: list of passage strings, one per row.\n",
    "        neg_num: number of negatives to draw per row.\n",
    "        rng: `numpy.random.Generator` used for sampling.\n",
    "\n",
    "    Returns:\n",
    "        A list with one list of `neg_num` passage strings per row.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a row has fewer than `neg_num` other rows whose\n",
    "            passage text differs from its own, so no valid draw exists.\n",
    "    \"\"\"\n",
    "    n_rows = len(passages)\n",
    "    text_counts = collections.Counter(passages)\n",
    "    negatives = []\n",
    "    for i, positive in enumerate(passages):\n",
    "        # Guarantees the rejection loop below can terminate.\n",
    "        n_candidates = n_rows - text_counts[positive]\n",
    "        if n_candidates < neg_num:\n",
    "            raise ValueError(\n",
    "                f\"row {i}: only {n_candidates} passages differ from its \"\n",
    "                f\"positive, cannot draw {neg_num} negatives\"\n",
    "            )\n",
    "        while True:\n",
    "            # Sample from the n_rows - 1 other rows, then shift indices >= i\n",
    "            # up by one so that row i itself can never be drawn.\n",
    "            ids = rng.choice(n_rows - 1, size=neg_num, replace=False)\n",
    "            ids[ids >= i] += 1\n",
    "            negs = [passages[j] for j in ids]\n",
    "            if positive not in negs:\n",
    "                break\n",
    "        negatives.append(negs)\n",
    "    return negatives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f4a6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A fresh generator here makes re-running this cell reproduce the same negatives.\n",
    "rng = np.random.default_rng(SEED)\n",
    "negatives = sample_negatives(list(qa_ds[\"pos\"]), NEG_NUM, rng)\n",
    "assert len(negatives) == len(qa_ds)\n",
    "assert all(len(negs) == NEG_NUM for negs in negatives)\n",
    "\n",
    "dataset = (\n",
    "    qa_ds\n",
    "    .add_column(\"neg\", negatives)\n",
    "    # The fine-tuning format expects `pos` to be a list of passages.\n",
    "    .map(lambda example: {\"pos\": [example[\"pos\"]]})\n",
    "    .add_column(\"prompt\", [INSTRUCTION] * len(qa_ds))\n",
    ")\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a9b0c1d",
   "metadata": {},
   "source": [
    "## Train/test split\n",
    "\n",
    "**Note:** negatives are sampled before splitting, so passages belonging to test rows can appear as negatives of training queries. If the test split must be strictly held out, split first and sample negatives within the train split only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35c1466",
   "metadata": {},
   "outputs": [],
   "source": [
    "split = dataset.train_test_split(test_size=TEST_SIZE, shuffle=True, seed=SEED)\n",
    "train, test = split[\"train\"], split[\"test\"]\n",
    "assert len(train) + len(test) == len(dataset)\n",
    "len(train), len(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1c2d3e4",
   "metadata": {},
   "source": [
    "## Export\n",
    "\n",
    "Write the train split; the returned value is the number of bytes written."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f3f7fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "train.to_json(OUTPUT_PATH)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}