add gemma train
This commit is contained in:
parent
467c21ce7e
commit
bc2cc07411
148
evaluation/evaluate.py
Normal file
148
evaluation/evaluate.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import importlib
|
||||||
|
import tqdm
|
||||||
|
from hazm import Normalizer
|
||||||
|
|
||||||
|
normalizer = Normalizer()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(input_file):
    """Read the evaluation dataset from a UTF-8 encoded JSON file.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        Whatever object ``json.load`` decodes from the file.
    """
    with open(input_file, mode="r", encoding="utf-8") as handle:
        return json.load(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_ndcg(scores, n):
    """Compute NDCG@n for a ranked list of graded relevance labels.

    The gain of a grade ``rel`` is ``2 ** rel - 1`` and the discount of the
    item at 0-based position ``i`` is ``log2(i + 2)``.

    Args:
        scores: Relevance grades in ranked order (best-ranked item first).
        n: Cut-off rank. Values outside ``[0, len(scores)]`` are clamped.

    Returns:
        ``DCG@n / IDCG@n``, or ``0.0`` when the ideal DCG is zero (empty
        list or no item with a positive gain), where the ratio is undefined.
    """
    # Clamp the cut-off so a short ranking does not raise IndexError.
    cutoff = max(0, min(n, len(scores)))

    def calculate_dcg(ranked, k):
        # Plain left-to-right accumulation keeps the float result stable.
        total = 0
        for position in range(k):
            gain = (2 ** ranked[position]) - 1
            discount = math.log2(position + 2)
            total += gain / discount
        return total

    def calculate_idcg(ranked, k):
        # The ideal ranking lists the same grades in descending order.
        return calculate_dcg(sorted(ranked, reverse=True), k)

    idcg = calculate_idcg(scores, cutoff)
    if idcg == 0:
        # Nothing relevant to rank: avoid ZeroDivisionError.
        return 0.0
    return calculate_dcg(scores, cutoff) / idcg
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_recall(scores):
    """Compute recall@7, recall@12, recall@20 and an R-precision style recall.

    Only items with grade ``4`` count as relevant.

    Args:
        scores: Relevance grades in ranked order (best-ranked item first).

    Returns:
        A tuple ``(recall_7, recall_12, recall_20, recall_variant)``.
        ``recall_variant`` is the share of grade-4 items found within the
        top ``R`` results, where ``R`` is the total number of grade-4 items.
        When the list holds no grade-4 item the result is ``(0, 0, 0, 0)``.
    """
    num_ground_truth = scores.count(4)
    if num_ground_truth == 0:
        # NOTE(review): the previous version switched the denominator to the
        # number of grade-3 items here, but still counted only grade 4 in
        # the numerators and then divided by ``scores.count(4)`` (zero), so
        # a blanket ``except`` always turned this case into all zeros.  That
        # outcome is kept, without swallowing unrelated errors.  Confirm
        # whether grade 3 was meant to act as the fallback relevant grade.
        return 0, 0, 0, 0

    recall_7 = scores[:7].count(4) / num_ground_truth
    recall_12 = scores[:12].count(4) / num_ground_truth
    recall_20 = scores[:20].count(4) / num_ground_truth
    recall_variant = scores[:num_ground_truth].count(4) / num_ground_truth

    return recall_7, recall_12, recall_20, recall_variant
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_precision(scores):
    """Compute precision@7, precision@12 and precision@20.

    Only items with grade ``4`` count as relevant.  The denominator is the
    cut-off itself, even when fewer than that many items were ranked.

    Args:
        scores: Relevance grades in ranked order (best-ranked item first).

    Returns:
        A tuple ``(precision_7, precision_12, precision_20)``.
    """
    cutoffs = (7, 12, 20)
    precision_7, precision_12, precision_20 = (
        scores[:k].count(4) / k for k in cutoffs
    )
    return precision_7, precision_12, precision_20
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_reranker(text:str, preprocess:bool=True, add_extra_word:bool=False):
    """Prepare a question or chunk before it is handed to the model.

    Args:
        text: The raw text.
        preprocess: When true, every newline is replaced by ``"."`` and the
            result is passed through the module-level ``normalizer``.
        add_extra_word: When true, a fixed Persian phrase is appended.

    Returns:
        The processed text.
    """
    if preprocess:
        text = normalizer.normalize(text.replace("\n", "."))

    suffix = " رهبر انقلاب اسلامی حضرت امام خامنه ای " if add_extra_word else ""
    return text + suffix
|
||||||
|
|
||||||
|
|
||||||
|
def run(input_file, model):
    """Score every chunk of every question with a model and print ranking metrics.

    The module ``evaluation.models.<model>`` is imported and its ``model``
    attribute is called to build the scorer.  For each dataset entry the
    scorer's ``run(question, chunk)`` result is used to rank the chunks, and
    the ground-truth grades in that order feed NDCG, recall and precision.
    The averages over all entries are printed.

    Args:
        input_file: Path of the JSON dataset.  Each entry is read through
            the keys ``"question"``, ``"chunks"`` and ``"scores"``; the
            latter two are indexed by the strings ``"0"``, ``"1"``, ...
        model: Module name inside the ``evaluation.models`` package.
    """
    module = importlib.import_module("evaluation.models." + model)
    scorer = module.model()

    metric_names = (
        "NDCG",
        "Recall 7",
        "Recall 12",
        "Recall 20",
        "Recall Variant",
        "Precision 7",
        "Precision 12",
        "Precision 20",
    )
    collected = {name: [] for name in metric_names}

    dataset = load_dataset(input_file)
    for data in tqdm.tqdm(dataset):
        # The question is the same for every chunk, so normalise it once.
        question = preprocess_reranker(data["question"], preprocess=True)
        chunk_keys = [str(index) for index in range(len(data["chunks"]))]
        chunks = [data["chunks"][key] for key in chunk_keys]
        scores_llm = [data["scores"][key] for key in chunk_keys]

        scores_embed = [
            scorer.run(question, preprocess_reranker(chunk, preprocess=True, add_extra_word=False))
            for chunk in chunks
        ]

        # Rank by the model score only.  Sorting the raw (score, grade)
        # tuples would break ties by the ground-truth grade, always putting
        # the more relevant chunk first and inflating the metrics.  The sort
        # is stable, so tied chunks keep their dataset order.
        sorted_pairs = sorted(
            zip(scores_embed, scores_llm), key=lambda pair: pair[0], reverse=True
        )
        scores = [rel for _, rel in sorted_pairs]

        recall_7, recall_12, recall_20, recall_variant = calculate_recall(scores)
        precision_7, precision_12, precision_20 = calculate_precision(scores)
        per_question = (
            calculate_ndcg(scores, len(scores)),
            recall_7,
            recall_12,
            recall_20,
            recall_variant,
            precision_7,
            precision_12,
            precision_20,
        )
        for name, value in zip(metric_names, per_question):
            collected[name].append(value)

    if not collected["NDCG"]:
        # An empty dataset has no averages; report it instead of dividing by zero.
        print(f"No samples found in {input_file}; nothing to evaluate.")
        return

    for name in metric_names:
        values = collected[name]
        print(f"{name}: {sum(values)/len(values)}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Command-line entry point of the evaluation script.

    -First give your questions to generate_dataset.py and generate a json file and give the path as input_file.
    -Second create your model class in ./models folder similar to sample_model.py
    -Third run the script with the following command:
    python evaluate.py --input_file <path_to_your_json_file> --model <path_to_your_model_class>

    Both options are mandatory: the input file is opened and the model name
    is appended to "evaluation.models." to build the module to import.
    """
    parser = argparse.ArgumentParser()
    # Without ``required=True`` a missing option only surfaced later as an
    # unhelpful TypeError; argparse now reports it with a usage message.
    parser.add_argument('--input_file', required=True, help='json input file path')
    parser.add_argument('--model', required=True, help='the path of model class')

    args = parser.parse_args()

    print(f"Start to evaluate the model {args.model} with normalizer and extra words input file {args.input_file}")
    run(args.input_file, args.model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # ``exit`` is a helper injected by the ``site`` module for interactive
    # sessions and may be absent; raising SystemExit always works and gives
    # the same exit status.
    raise SystemExit(main())
|
||||||
148
evaluation/evaluate_50.py
Normal file
148
evaluation/evaluate_50.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import importlib
|
||||||
|
import tqdm
|
||||||
|
from hazm import Normalizer
|
||||||
|
|
||||||
|
normalizer = Normalizer()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(input_file):
    """Read the evaluation dataset from a UTF-8 encoded JSON file.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        Whatever object ``json.load`` decodes from the file.
    """
    with open(input_file, mode="r", encoding="utf-8") as handle:
        return json.load(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_ndcg(scores, n):
    """Compute NDCG@n for a ranked list of graded relevance labels.

    The gain of a grade ``rel`` is ``2 ** rel - 1`` and the discount of the
    item at 0-based position ``i`` is ``log2(i + 2)``.

    Args:
        scores: Relevance grades in ranked order (best-ranked item first).
        n: Cut-off rank. Values outside ``[0, len(scores)]`` are clamped.

    Returns:
        ``DCG@n / IDCG@n``, or ``0.0`` when the ideal DCG is zero (empty
        list or no item with a positive gain), where the ratio is undefined.
    """
    # Clamp the cut-off so a short ranking does not raise IndexError.
    cutoff = max(0, min(n, len(scores)))

    def calculate_dcg(ranked, k):
        # Plain left-to-right accumulation keeps the float result stable.
        total = 0
        for position in range(k):
            gain = (2 ** ranked[position]) - 1
            discount = math.log2(position + 2)
            total += gain / discount
        return total

    def calculate_idcg(ranked, k):
        # The ideal ranking lists the same grades in descending order.
        return calculate_dcg(sorted(ranked, reverse=True), k)

    idcg = calculate_idcg(scores, cutoff)
    if idcg == 0:
        # Nothing relevant to rank: avoid ZeroDivisionError.
        return 0.0
    return calculate_dcg(scores, cutoff) / idcg
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_recall(scores):
    """Compute recall@7, recall@12, recall@20 and an R-precision style recall.

    Only items with grade ``4`` count as relevant.

    Args:
        scores: Relevance grades in ranked order (best-ranked item first).

    Returns:
        A tuple ``(recall_7, recall_12, recall_20, recall_variant)``.
        ``recall_variant`` is the share of grade-4 items found within the
        top ``R`` results, where ``R`` is the total number of grade-4 items.
        When the list holds no grade-4 item the result is ``(0, 0, 0, 0)``.
    """
    num_ground_truth = scores.count(4)
    if num_ground_truth == 0:
        # NOTE(review): the previous version switched the denominator to the
        # number of grade-3 items here, but still counted only grade 4 in
        # the numerators and then divided by ``scores.count(4)`` (zero), so
        # a blanket ``except`` always turned this case into all zeros.  That
        # outcome is kept, without swallowing unrelated errors.  Confirm
        # whether grade 3 was meant to act as the fallback relevant grade.
        return 0, 0, 0, 0

    recall_7 = scores[:7].count(4) / num_ground_truth
    recall_12 = scores[:12].count(4) / num_ground_truth
    recall_20 = scores[:20].count(4) / num_ground_truth
    recall_variant = scores[:num_ground_truth].count(4) / num_ground_truth

    return recall_7, recall_12, recall_20, recall_variant
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_precision(scores):
    """Compute precision@7, precision@12 and precision@20.

    Only items with grade ``4`` count as relevant.  The denominator is the
    cut-off itself, even when fewer than that many items were ranked.

    Args:
        scores: Relevance grades in ranked order (best-ranked item first).

    Returns:
        A tuple ``(precision_7, precision_12, precision_20)``.
    """
    cutoffs = (7, 12, 20)
    precision_7, precision_12, precision_20 = (
        scores[:k].count(4) / k for k in cutoffs
    )
    return precision_7, precision_12, precision_20
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_reranker(text:str, preprocess:bool=True, add_extra_word:bool=False):
    """Prepare a question or chunk before it is handed to the model.

    Args:
        text: The raw text.
        preprocess: When true, every newline is replaced by ``"."`` and the
            result is passed through the module-level ``normalizer``.
        add_extra_word: When true, a fixed Persian phrase is appended.

    Returns:
        The processed text.
    """
    if preprocess:
        text = normalizer.normalize(text.replace("\n", "."))

    suffix = " رهبر انقلاب اسلامی حضرت امام خامنه ای " if add_extra_word else ""
    return text + suffix
|
||||||
|
|
||||||
|
|
||||||
|
def run(input_file, model):
    """Score every chunk of every question with a model and print ranking metrics.

    The module ``evaluation.models.<model>`` is imported and its ``model``
    attribute is called to build the scorer.  For each dataset entry the
    scorer's ``run(question, chunk)`` result is used to rank the chunks, and
    the ground-truth grades in that order feed NDCG, recall and precision.
    The averages over all entries are printed.

    Args:
        input_file: Path of the JSON dataset.  Each entry is read through
            the keys ``"question"``, ``"chunks"`` and ``"scores"``; the
            latter two are indexed by the strings ``"0"``, ``"1"``, ...
        model: Module name inside the ``evaluation.models`` package.
    """
    module = importlib.import_module("evaluation.models." + model)
    scorer = module.model()

    metric_names = (
        "NDCG",
        "Recall 7",
        "Recall 12",
        "Recall 20",
        "Recall Variant",
        "Precision 7",
        "Precision 12",
        "Precision 20",
    )
    collected = {name: [] for name in metric_names}

    dataset = load_dataset(input_file)
    for data in tqdm.tqdm(dataset):
        # The question is the same for every chunk, so normalise it once.
        question = preprocess_reranker(data["question"], preprocess=True)
        chunk_keys = [str(index) for index in range(len(data["chunks"]))]
        chunks = [data["chunks"][key] for key in chunk_keys]
        scores_llm = [data["scores"][key] for key in chunk_keys]

        scores_embed = [
            scorer.run(question, preprocess_reranker(chunk, preprocess=True, add_extra_word=False))
            for chunk in chunks
        ]

        # Rank by the model score only.  Sorting the raw (score, grade)
        # tuples would break ties by the ground-truth grade, always putting
        # the more relevant chunk first and inflating the metrics.  The sort
        # is stable, so tied chunks keep their dataset order.
        sorted_pairs = sorted(
            zip(scores_embed, scores_llm), key=lambda pair: pair[0], reverse=True
        )
        scores = [rel for _, rel in sorted_pairs]

        recall_7, recall_12, recall_20, recall_variant = calculate_recall(scores)
        precision_7, precision_12, precision_20 = calculate_precision(scores)
        per_question = (
            calculate_ndcg(scores, len(scores)),
            recall_7,
            recall_12,
            recall_20,
            recall_variant,
            precision_7,
            precision_12,
            precision_20,
        )
        for name, value in zip(metric_names, per_question):
            collected[name].append(value)

    if not collected["NDCG"]:
        # An empty dataset has no averages; report it instead of dividing by zero.
        print(f"No samples found in {input_file}; nothing to evaluate.")
        return

    for name in metric_names:
        values = collected[name]
        print(f"{name}: {sum(values)/len(values)}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Command-line entry point of the evaluation script.

    -First give your questions to generate_dataset.py and generate a json file and give the path as input_file.
    -Second create your model class in ./models folder similar to sample_model.py
    -Third run the script with the following command:
    python evaluate_50.py --input_file <path_to_your_json_file> --model <path_to_your_model_class>

    Both options are mandatory: the input file is opened and the model name
    is appended to "evaluation.models." to build the module to import.
    """
    parser = argparse.ArgumentParser()
    # Without ``required=True`` a missing option only surfaced later as an
    # unhelpful TypeError; argparse now reports it with a usage message.
    parser.add_argument('--input_file', required=True, help='json input file path')
    parser.add_argument('--model', required=True, help='the path of model class')

    args = parser.parse_args()

    print(f"Start to evaluate the model {args.model} with normalizer and extra words input file {args.input_file}")
    run(args.input_file, args.model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # ``exit`` is a helper injected by the ``site`` module for interactive
    # sessions and may be absent; raising SystemExit always works and gives
    # the same exit status.
    raise SystemExit(main())
|
||||||
34
notes.txt
34
notes.txt
@ -18,4 +18,36 @@
|
|||||||
|
|
||||||
9-longragfa dataset: it is long doc and query and for evaluation : question = 250, passage = 1500 : not using
|
9-longragfa dataset: it is long doc and query and for evaluation : question = 250, passage = 1500 : not using
|
||||||
|
|
||||||
10-Synthetic-persian-qa-retrieval dataset : question = 223423, passage = 250000 : negative passages are not exactly different : needs preprocessing
|
10-Synthetic-persian-qa-retrieval dataset : question = 223423, passage = 250000 : negative passages are not exactly different : needs preprocessing
|
||||||
|
|
||||||
|
no train
|
||||||
|
NDCG: 0.8452119768348717
|
||||||
|
Recall 7: 0.3373666606161222
|
||||||
|
Recall 12: 0.48390155482482855
|
||||||
|
Recall 20: 0.6340810809380268
|
||||||
|
Recall Variant: 0.44313617731261423
|
||||||
|
Precision 7: 0.4714285714285715
|
||||||
|
Precision 12: 0.41999999999999993
|
||||||
|
Precision 20: 0.358
|
||||||
|
|
||||||
|
train with 100
|
||||||
|
NDCG: 0.8007791818263832
|
||||||
|
Recall 7: 0.2617863643550479
|
||||||
|
Recall 12: 0.3759745806720163
|
||||||
|
Recall 20: 0.5564983103150418
|
||||||
|
Recall Variant: 0.36642345327979325
|
||||||
|
Precision 7: 0.3828571428571429
|
||||||
|
Precision 12: 0.3449999999999999
|
||||||
|
Precision 20: 0.311
|
||||||
|
|
||||||
|
train with 100 with lora
|
||||||
|
NDCG: 0.8432282495018343
|
||||||
|
Recall 7: 0.33695911259587386
|
||||||
|
Recall 12: 0.4729916144600827
|
||||||
|
Recall 20: 0.6212526155736547
|
||||||
|
Recall Variant: 0.43208929205133273
|
||||||
|
Precision 7: 0.4685714285714285
|
||||||
|
Precision 12: 0.4099999999999999
|
||||||
|
Precision 20: 0.35200000000000004
|
||||||
|
|
||||||
|
train with 100 with prompt
|
||||||
|
|||||||
@ -11,10 +11,7 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
"/home/firouzi/embedding_model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||||
"Downloading readme: 100%|██████████| 419/419 [00:00<00:00, 1.18MB/s]\n",
|
|
||||||
"Downloading data: 100%|██████████| 1.59M/1.59M [00:01<00:00, 1.03MB/s]\n",
|
|
||||||
"Generating train split: 100%|██████████| 7000/7000 [00:00<00:00, 175360.77 examples/s]\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -56,15 +53,7 @@
|
|||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"id": "5ba361dd",
|
"id": "5ba361dd",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Map: 100%|██████████| 7000/7000 [00:00<00:00, 19176.72 examples/s]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -102,48 +91,78 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 12,
|
||||||
"id": "a35c1466",
|
"id": "a35c1466",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)\n",
|
"split = ds.train_test_split(test_size=0.02, shuffle=True, seed=520)\n",
|
||||||
"train = split[\"train\"]\n",
|
"train = split[\"train\"]\n",
|
||||||
"test = split[\"test\"]"
|
"test = split[\"test\"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 13,
|
||||||
"id": "24f3f7fb",
|
"id": "aec6787d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"140"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"len(test)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"id": "c5cc42ed",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 26.22ba/s]\n"
|
"Creating json from Arrow format: 0%| | 0/7 [00:00<?, ?ba/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 21.58ba/s]\n",
|
||||||
|
"Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 148.87ba/s]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"16583481"
|
"364936"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 6,
|
"execution_count": 14,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"train.to_json(\"training.json\")"
|
"train.to_json(\"training.json\")\n",
|
||||||
|
"test.to_json(\"test.json\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "c5cc42ed",
|
"id": "536227f7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
|
|||||||
@ -149,10 +149,39 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 2,
|
||||||
"id": "1fabd9d8",
|
"id": "1fabd9d8",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"import random\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"/home/firouzi/embedding_model/data/train.json\", \"r\", encoding=\"utf-8\") as f:\n",
|
||||||
|
" all_dataset = json.load(f)\n",
|
||||||
|
"\n",
|
||||||
|
"test_ratio = 0.05  # hold out 5% of the samples as test data\n",
|
||||||
|
"\n",
|
||||||
|
"# Shuffle and split\n",
|
||||||
|
"random.shuffle(all_dataset)\n",
|
||||||
|
"split_index = int(len(all_dataset) * (1 - test_ratio))\n",
|
||||||
|
"train_data = all_dataset[:split_index]\n",
|
||||||
|
"test_data = all_dataset[split_index:]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"/home/firouzi/embedding_model/data/train_test.json\", \"w\", encoding=\"utf-8\") as f:\n",
|
||||||
|
" json.dump(test_data, f, ensure_ascii=False, indent=4)\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"/home/firouzi/embedding_model/data/train_train.json\", \"w\", encoding=\"utf-8\") as f:\n",
|
||||||
|
" json.dump(train_data, f, ensure_ascii=False, indent=4)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8be4382d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
23
train/gemma/gemma_inference.py
Normal file
23
train/gemma/gemma_inference.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import torch
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
# model_id = "google/embeddinggemma-300M"
|
||||||
|
model_id = "my-embedding-gemma/checkpoint-15"
|
||||||
|
model = SentenceTransformer(model_id).to(device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_scores(query, document):
    """Print and return the similarity between a query and a document.

    The module-level ``model`` encodes the query with ``encode_query`` and
    the document with ``encode_document``; ``model.similarity`` compares
    the two embeddings.

    Args:
        query: Query text.
        document: Document text.

    Returns:
        The object produced by ``model.similarity``.  Earlier the function
        only printed it, so callers that ignore the return value are
        unaffected.
    """
    query_embedding = model.encode_query(query)
    doc_embedding = model.encode_document(document)

    # Calculate the embedding similarities
    similarities = model.similarity(query_embedding, doc_embedding)

    print(similarities)
    return similarities
|
||||||
|
|
||||||
|
query = "I want to start a tax-free installment investment, what should I do?"
|
||||||
|
documents = "Opening a NISA Account"
|
||||||
|
|
||||||
|
get_scores(query, documents)
|
||||||
95
train/gemma/gemma_train.py
Normal file
95
train/gemma/gemma_train.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import json
|
||||||
|
from datasets import Dataset
|
||||||
|
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
||||||
|
from sentence_transformers.losses import MultipleNegativesRankingLoss
|
||||||
|
import argparse
|
||||||
|
from peft import LoraConfig, TaskType
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
|
def get_ndcg(model, dataset):
    """Placeholder evaluation hook: print the start of one query embedding.

    Despite the name no NDCG is computed yet.  The fixed query ``"hey"`` is
    encoded with ``model.encode_query`` and the first 20 values are printed.

    Args:
        model: Object exposing ``encode_query``.
        dataset: Currently unused.
    """
    preview = model.encode_query("hey")[:20]
    print(preview)
|
||||||
|
|
||||||
|
|
||||||
|
def main(add_prompt, lora):
    """Fine-tune ``google/embeddinggemma-300M`` on (anchor, positive, negative) triplets.

    One training row is built per negative passage of every dataset entry:
    the question is the anchor, the first positive passage is the positive
    and the negative passage is the negative.  The model is trained with
    ``MultipleNegativesRankingLoss`` and checkpoints go to ``./models/gemma``.

    Args:
        add_prompt: When true, the query and document prompt prefixes are
            prepended to the anchor and to both passages.
        lora: When true, a LoRA adapter is attached to the model before
            training.
    """
    import torch  # local import: only needed to choose the training device

    ########### Load dataset ###########
    print("loading dataset")
    with open("/home/firouzi/embedding_model/data/dataset_train.json", "r", encoding="utf-8") as f:
        all_dataset = json.load(f)

    # An empty prefix leaves the text unchanged, so one code path serves
    # both the prompted and the plain variant.
    query_prompt = "task: search result | query: " if add_prompt else ""
    document_prompt = "title: none | text: " if add_prompt else ""

    data_as_dicts = []
    for data in all_dataset:
        for data_neg in (data["passage_negative"] + data["passage_negative_random"]):
            data_as_dicts.append({
                "anchor": query_prompt + data["question"],
                "positive": document_prompt + data["passage_positive"][0],
                "negative": document_prompt + data_neg,
            })

    train_dataset = Dataset.from_list(data_as_dicts)
    print(f"len train_dataset: {len(train_dataset)}")

    ####################################
    print("loading model")
    # Fall back to the CPU when no GPU is present instead of failing on a
    # hard-coded "cuda:0".
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer("google/embeddinggemma-300M").to(device=device)

    if lora:
        # Create a LoRA adapter for the model
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=64,
            lora_alpha=128,
            lora_dropout=0.1,
        )
        model.add_adapter(peft_config)

    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir="./models/gemma",
        num_train_epochs=1,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        warmup_ratio=0.05,
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_steps=10000,
        save_total_limit=2,
    )

    class MyCallback(TrainerCallback):
        "A callback that runs the given evaluation function whenever the trainer logs"
        def __init__(self, evaluate):
            self.evaluate = evaluate  # evaluate function

        def on_log(self, args, state, control, **kwargs):
            # Report the current step, then run the evaluation function
            print(f"Step {state.global_step} finished. Running evaluation:")
            self.evaluate()

    def evaluate():
        get_ndcg(model, train_dataset)

    print("start to training model...")
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        # The evaluation callback is currently disabled:
        # callbacks=[MyCallback(evaluate)]
    )
    trainer.train()
    print("training done")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Both options are plain on/off switches that default to False.
    parser = argparse.ArgumentParser()
    for switch in ("--add_prompt", "--lora"):
        parser.add_argument(switch, action="store_true")
    args = parser.parse_args()

    # Echo the LoRA switch before training starts.
    print(args.lora)
    main(args.add_prompt, args.lora)
|
||||||
92
train/gemma/gemma_train_small.py
Normal file
92
train/gemma/gemma_train_small.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import torch
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# Login into Hugging Face Hub
|
||||||
|
# from huggingface_hub import login
|
||||||
|
# login()
|
||||||
|
|
||||||
|
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
model_id = "google/embeddinggemma-300M"
|
||||||
|
model = SentenceTransformer(model_id).to(device=device)
|
||||||
|
|
||||||
|
print(f"Device: {model.device}")
|
||||||
|
print(model)
|
||||||
|
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
dataset = [
|
||||||
|
["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
|
||||||
|
["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
|
||||||
|
["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert the list-based dataset into a list of dictionaries.
|
||||||
|
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]
|
||||||
|
|
||||||
|
# Create a Hugging Face `Dataset` object from the list of dictionaries.
|
||||||
|
train_dataset = Dataset.from_list(data_as_dicts)
|
||||||
|
print(train_dataset)
|
||||||
|
|
||||||
|
task_name = "STS"
|
||||||
|
|
||||||
|
def get_scores(query, documents):
    """Print the similarity of every document to the query.

    Embeddings come from the module-level ``model`` using the prompt named
    by the module-level ``task_name``.

    Args:
        query: Query text.
        documents: Sequence of document texts.
    """
    # Calculate embeddings by calling model.encode()
    query_embeddings = model.encode(query, prompt_name=task_name)
    doc_embeddings = model.encode(documents, prompt_name=task_name)

    # Calculate the embedding similarities and convert them once, instead
    # of once per printed document.
    similarities = model.similarity(query_embeddings, doc_embeddings)
    query_row = similarities.numpy()[0]

    for idx, doc in enumerate(documents):
        print("Document: ", doc, "-> 🤖 Score: ", query_row[idx])
|
||||||
|
|
||||||
|
query = "I want to start a tax-free installment investment, what should I do?"
|
||||||
|
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]
|
||||||
|
|
||||||
|
get_scores(query, documents)
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
||||||
|
from sentence_transformers.losses import MultipleNegativesRankingLoss
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
loss = MultipleNegativesRankingLoss(model)
|
||||||
|
|
||||||
|
args = SentenceTransformerTrainingArguments(
|
||||||
|
# Required parameter:
|
||||||
|
output_dir="my-embedding-gemma",
|
||||||
|
# Optional training parameters:
|
||||||
|
prompts=model.prompts[task_name], # use model's prompt to train
|
||||||
|
num_train_epochs=5,
|
||||||
|
per_device_train_batch_size=1,
|
||||||
|
learning_rate=2e-5,
|
||||||
|
warmup_ratio=0.1,
|
||||||
|
# Optional tracking/debugging parameters:
|
||||||
|
logging_steps=train_dataset.num_rows,
|
||||||
|
report_to="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
class MyCallback(TrainerCallback):
    "A callback that runs the given evaluation function whenever the trainer logs"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Report the current step, then run the evaluation function
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()
|
||||||
|
|
||||||
|
def evaluate():
    # Print the similarity scores of the module-level sample query against
    # the module-level sample documents.
    get_scores(query, documents)
|
||||||
|
|
||||||
|
trainer = SentenceTransformerTrainer(
|
||||||
|
model=model,
|
||||||
|
args=args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
loss=loss,
|
||||||
|
callbacks=[MyCallback(evaluate)]
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
get_scores(query, documents)
|
||||||
Loading…
x
Reference in New Issue
Block a user