import argparse
import json

from datasets import load_dataset
from tqdm import tqdm

from data_preprocess.remove_false_negative_model import LLMModel
from data_preprocess.generate_random_negative_sample import generate_random_negative_sample


# single shared LLM model used to flag false negatives, instantiated once at import time
llm_model = LLMModel()


def load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path):
    """
    load synthetic dataset from local jsonl files

    output:
    [{
        "question": "",
        "passage_positive": [],
        "passage_negative": [],
        "passage_negative_random": []
    }]
    """
    dataset_synthetic_scores = []
    with open(synthetic_train_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            dataset_synthetic_scores.append(data)

    dataset_synthetic_queries = {}
    with open(synthetic_queries_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            dataset_synthetic_queries[json_data['_id']] = json_data

    dataset_synthetic_corpus = {}
    with open(synthetic_corpus_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            dataset_synthetic_corpus[json_data['_id']] = json_data

    # build one record per query with question, passage_positive, passage_negative, passage_negative_random
    all_dataset = {}
    for data_topic in dataset_synthetic_scores:

        query_id = data_topic['query-id']
        # corpus "_id"s are assumed to be numeric, hence the int() cast before the lookup
        corpus_id = int(data_topic['corpus-id'])
        score = data_topic['score']

        # keep only pairs whose query and passage both exist in the loaded files
        if (query_id in dataset_synthetic_queries) and (corpus_id in dataset_synthetic_corpus):
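            # a score of "1" marks a positive pair; any other score is treated as a mined negative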
            if score == "1":
                if query_id in all_dataset:
                    all_dataset[query_id]['passage_positive'].append(dataset_synthetic_corpus[corpus_id]['text'])
                else:
                    all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'],
                                             'passage_positive': [dataset_synthetic_corpus[corpus_id]['text']],
                                             'passage_negative': [],
                                             'passage_negative_random': []}
            else:
                if query_id in all_dataset:
                    all_dataset[query_id]['passage_negative'].append(dataset_synthetic_corpus[corpus_id]['text'])
                else:
                    all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'],
                                             'passage_positive': [],
                                             'passage_negative': [dataset_synthetic_corpus[corpus_id]['text']],
                                             'passage_negative_random': []}

    all_dataset = list(all_dataset.values())

    return all_dataset


def load_pquad_dataset():
    """
    load pquad dataset from huggingface

    output:
    [{
        "question": "",
        "passage_positive": [],
        "passage_negative": [],
        "passage_negative_random": []
    }]
    """
    dataset = load_dataset("Gholamreza/pquad", trust_remote_code=True)
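    # rows with an empty answers list are unanswerable questions (PQuAD is SQuAD-v2-style),
    # so only answerable rows are kept; their context serves as the positive passage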
    all_dataset = []
    for data in dataset["train"]:
        if len(data["answers"]["text"]) > 0:
            all_dataset.append({'question': data['question'],
                                'passage_positive': [data['context']],
                                'passage_negative': [],
                                'passage_negative_random': []})

    return all_dataset


def remove_false_negative(dataset, random_negative_sample=False):
    """
    remove false negative samples from the dataset

    Args:
        dataset: list of dicts
        random_negative_sample: if True, filter "passage_negative_random"; otherwise "passage_negative"
    Returns:
        dataset: list of dicts
    """
    # pick which negative list to filter
    if random_negative_sample:
        negative_name = "passage_negative_random"
    else:
        negative_name = "passage_negative"

    # score each (question, negative passage) pair with the LLM, in batches of queries
    negative_count_all = 0
    negative_count_removed = 0
    len_dataset = len(dataset)
    batch_size = 50
    for i in tqdm(range(0, len_dataset, batch_size)):

        # flatten the batch into parallel lists of questions and candidate negatives
        question_list = []
        passage_negative_list = []
        for idx in range(i, min(i + batch_size, len_dataset)):
            for passage in dataset[idx][negative_name]:
                question_list.append(dataset[idx]['question'])
                passage_negative_list.append(passage)

        results = llm_model.remove_false_negative_llm(question_list, passage_negative_list)

        # the model returns "1" for a false negative (actually relevant) and "0" otherwise
        negative_count_removed += len([_ for _ in results if _ == "1"])
        negative_count_all += len(results)
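        # walk the flat results back in the same iteration order, keeping only passages judged "0"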
        count = 0
        for idx in range(i, min(i + batch_size, len_dataset)):
            new_negative_list = []
            for passage_id in range(len(dataset[idx][negative_name])):
                if results[count] == "0":
                    new_negative_list.append(dataset[idx][negative_name][passage_id])
                count += 1
            dataset[idx][negative_name] = new_negative_list

    print(f"removed {negative_count_removed} false negative samples from {negative_count_all} samples")
    print("--------------------------------")

    return dataset


def save_dataset(dataset, output_path):
    """
    save dataset to json file

    Args:
        dataset: list of dicts
        output_path: path to save dataset
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)


def main(output_path):

    # load synthetic dataset (paths are hardcoded to the local copy of the data)
    print("--------------------------------")
    print("loading synthetic dataset")
    synthetic_train_path = "/home/firouzi/embedding_model/research_notebook/data/synthetic-persian-qa-retrieval/train.jsonl"
    synthetic_corpus_path = "/home/firouzi/embedding_model/research_notebook/data/synthetic-persian-qa-retrieval/corpus.jsonl"
    synthetic_queries_path = "/home/firouzi/embedding_model/research_notebook/data/synthetic-persian-qa-retrieval/queries.jsonl"

    synthetic_dataset = load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path)
    print(f"synthetic dataset loaded : {len(synthetic_dataset)} samples")
    print("--------------------------------")

    # load pquad dataset
    print("loading pquad dataset")
    pquad_dataset = load_pquad_dataset()
    print(f"pquad dataset loaded : {len(pquad_dataset)} samples")
    print("--------------------------------")

    # merge synthetic and pquad dataset
    print("start to merge synthetic and pquad dataset")
    all_dataset = synthetic_dataset + pquad_dataset
    print("successfully merged synthetic and pquad dataset")
    print("--------------------------------")

    # optional step (disabled): remove false negatives from the mined negatives as well
    # print("start to remove false negative samples from all dataset")
    # all_dataset = remove_false_negative(all_dataset, random_negative_sample=False)
    # print("successfully removed false negative samples from all dataset")
    # print("--------------------------------")

    # optional step (disabled): resume from a previously saved dataset and reset its random negatives
    # with open("/home/firouzi/embedding_model/data/train.json", "r", encoding="utf-8") as f:
    #     all_dataset = json.load(f)
    # for i in range(len(all_dataset)):
    #     all_dataset[i]['passage_negative_random'] = []

    # generate random negative samples
    print("start to generate random negative samples")
    all_dataset = generate_random_negative_sample(all_dataset)
    print("successfully generated random negative samples")
    print("--------------------------------")

    # remove random false negative samples from all dataset
    print("start to remove random false negative samples from all dataset")
    all_dataset = remove_false_negative(all_dataset, random_negative_sample=True)
    print("successfully removed random false negative samples from all dataset")
    print("--------------------------------")

    # save dataset
    print("start to save dataset")
    save_dataset(all_dataset, output_path)
    print("successfully saved dataset")
    print("--------------------------------")


if __name__ == "__main__":
    """
    preprocess dataset for training

    pipeline:
        load synthetic dataset from local jsonl files
        load pquad dataset from huggingface
        merge synthetic and pquad dataset
        generate random negative samples
        remove random false negative samples (filtering of the mined negatives is currently disabled)
        save dataset to json file

    usage:
        python preprocess_v1.py --output_path /home/firouzi/embedding_model/data/train.json
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    output_path = args.output_path

    main(output_path)