# embedding_model/data_preprocess/preprocess_v1.py
import argparse
import json

from datasets import load_dataset
from tqdm import tqdm

from data_preprocess.remove_false_negative_model import LLMModel
from data_preprocess.generate_random_negative_sample import generate_random_negative_sample

# Shared LLM used to flag false-negative passages; instantiated once at import time.
llm_model = LLMModel()


def load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path):
    """
    Load the synthetic dataset from local jsonl files.
    Output:
        [{
            "question": "",
            "passage_positive": [],
            "passage_negative": [],
            "passage_negative_random": []
        }]
    """
    dataset_synthetic_scores = []
    with open(synthetic_train_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            dataset_synthetic_scores.append(data)
    dataset_synthetic_queries = {}
    with open(synthetic_queries_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            dataset_synthetic_queries[json_data['_id']] = json_data
    dataset_synthetic_corpus = {}
    with open(synthetic_corpus_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            dataset_synthetic_corpus[json_data['_id']] = json_data
    # Build one record per query with question, passage_positive, passage_negative,
    # and an (initially empty) passage_negative_random list.
    all_dataset = {}
    for data_topic in dataset_synthetic_scores:
        query_id = data_topic['query-id']
        corpus_id = int(data_topic['corpus-id'])
        score = data_topic['score']
        if (query_id in dataset_synthetic_queries) and (corpus_id in dataset_synthetic_corpus):
            if score == "1":
                if query_id in all_dataset:
                    all_dataset[query_id]['passage_positive'].append(dataset_synthetic_corpus[corpus_id]['text'])
                else:
                    all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'],
                                             'passage_positive': [dataset_synthetic_corpus[corpus_id]['text']],
                                             'passage_negative': [],
                                             'passage_negative_random': []}
            else:
                if query_id in all_dataset:
                    all_dataset[query_id]['passage_negative'].append(dataset_synthetic_corpus[corpus_id]['text'])
                else:
                    all_dataset[query_id] = {'question': dataset_synthetic_queries[query_id]['text'],
                                             'passage_positive': [],
                                             'passage_negative': [dataset_synthetic_corpus[corpus_id]['text']],
                                             'passage_negative_random': []}
    all_dataset = list(all_dataset.values())
    return all_dataset
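
# For reference, a minimal sketch of the jsonl line shapes load_synthetic_dataset
# assumes (field names inferred from the parsing above; real files may carry
# extra fields, which are simply ignored):
#   train.jsonl:   {"query-id": "q1", "corpus-id": "17", "score": "1"}
#   queries.jsonl: {"_id": "q1", "text": "..."}
#   corpus.jsonl:  {"_id": 17, "text": "..."}
# Note: corpus-id is cast to int before the lookup, so corpus "_id" values are
# expected to be integers.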


def load_pquad_dataset():
    """
    Load the pquad dataset from huggingface.
    Output:
        [{
            "question": "",
            "passage_positive": [],
            "passage_negative": [],
            "passage_negative_random": []
        }]
    """
    dataset = load_dataset("Gholamreza/pquad", trust_remote_code=True)
    all_dataset = []
    for data in dataset["train"]:
        # Keep only answerable questions; unanswerable ones have no positive passage.
        if len(data["answers"]["text"]) > 0:
            all_dataset.append({'question': data['question'],
                                'passage_positive': [data['context']],
                                'passage_negative': [],
                                'passage_negative_random': []})
    return all_dataset
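
# A pquad "train" record is SQuAD-style; roughly (a sketch, other fields such
# as "id" and "title" omitted):
#   {"question": "...", "context": "...",
#    "answers": {"text": ["..."], "answer_start": [123]}}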


def remove_false_negative(dataset, random_negative_sample=False):
    """
    Remove false-negative samples from the dataset in place.
    Args:
        dataset: list of dicts
        random_negative_sample: if True, filter "passage_negative_random"
            instead of "passage_negative"
    Returns:
        dataset: list of dicts
    """
    if random_negative_sample:
        negative_name = "passage_negative_random"
    else:
        negative_name = "passage_negative"
    negative_count_all = 0
    negative_count_removed = 0
    len_dataset = len(dataset)
    batch_size = 50
    for i in tqdm(range(0, len_dataset, batch_size)):
        # Flatten every (question, negative passage) pair in this batch so the
        # LLM can score them in one call.
        question_list = []
        passage_negative_list = []
        for idx in range(i, min(i + batch_size, len_dataset)):
            for passage in dataset[idx][negative_name]:
                question_list.append(dataset[idx]['question'])
                passage_negative_list.append(passage)
        # The model returns one label per pair: "1" marks a false negative
        # (the passage actually answers the question), "0" a true negative.
        results = llm_model.remove_false_negative_llm(question_list, passage_negative_list)
        negative_count_removed += len([r for r in results if r == "1"])
        negative_count_all += len(results)
        # Walk the batch again in the same order and keep only true negatives.
        count = 0
        for idx in range(i, min(i + batch_size, len_dataset)):
            new_negative_list = []
            for passage_id in range(len(dataset[idx][negative_name])):
                if results[count] == "0":
                    new_negative_list.append(dataset[idx][negative_name][passage_id])
                count += 1
            dataset[idx][negative_name] = new_negative_list
    print(f"removed {negative_count_removed} false negative samples from {negative_count_all} samples")
    print("--------------------------------")
    return dataset
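
# Usage sketch (assuming records shaped like the loaders above produce):
#   data = [{'question': 'q', 'passage_positive': ['p+'],
#            'passage_negative': ['n1', 'n2'], 'passage_negative_random': []}]
#   data = remove_false_negative(data)  # drops every negative the LLM labels "1"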


def save_dataset(dataset, output_path):
    """
    Save the dataset to a json file.
    Args:
        dataset: list of dicts
        output_path: path to save the dataset to
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)
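
# The saved file is a single JSON array of the record dicts shown above;
# main() below reloads such a file with json.load() before adding random negatives.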


def main(output_path):
    # # Load the synthetic dataset.
    # print("--------------------------------")
    # print("loading synthetic dataset")
    # synthetic_train_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/train.jsonl"
    # synthetic_corpus_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/corpus.jsonl"
    # synthetic_queries_path = "/home/firouzi/embedding_model/data_preprocess_notebook/data/synthetic-persian-qa-retrieval/queries.jsonl"
    # synthetic_dataset = load_synthetic_dataset(synthetic_train_path, synthetic_queries_path, synthetic_corpus_path)
    # print(f"synthetic dataset loaded : {len(synthetic_dataset)} samples")
    # print("--------------------------------")
    # # Load the pquad dataset.
    # print("loading pquad dataset")
    # pquad_dataset = load_pquad_dataset()
    # print(f"pquad dataset loaded : {len(pquad_dataset)} samples")
    # print("--------------------------------")
    # # Merge the synthetic and pquad datasets.
    # print("start to merge synthetic and pquad dataset")
    # all_dataset = synthetic_dataset + pquad_dataset
    # print("successfully merged synthetic and pquad dataset")
    # print("--------------------------------")
    # # Remove false-negative samples from the merged dataset.
    # print("start to remove false negative samples from all dataset")
    # all_dataset = remove_false_negative(all_dataset, random_negative_sample=False)
    # print("successfully removed false negative samples from all dataset")
    # print("--------------------------------")
    # Resume from a previously saved dataset (apparently the output of the
    # commented-out steps above) and reset its random negatives.
    with open("/home/firouzi/embedding_model/data/train.json", "r", encoding="utf-8") as f:
        all_dataset = json.load(f)
    for i in range(len(all_dataset)):
        all_dataset[i]['passage_negative_random'] = []
    # Generate random negative samples.
    print("start to generate random negative samples")
    all_dataset = generate_random_negative_sample(all_dataset)
    print("successfully generated random negative samples")
    print("--------------------------------")
    # Remove random false-negative samples from the dataset.
    print("start to remove random false negative samples from all dataset")
    all_dataset = remove_false_negative(all_dataset, random_negative_sample=True)
    print("successfully removed random false negative samples from all dataset")
    print("--------------------------------")
    # Save the dataset.
    print("start to save dataset")
    save_dataset(all_dataset, output_path)
    print("successfully saved dataset")
    print("--------------------------------")


if __name__ == "__main__":
    """
    Preprocess the dataset for training.
    Pipeline:
        load the synthetic dataset from local jsonl files
        load the pquad dataset from huggingface
        merge the synthetic and pquad datasets
        remove false-negative samples
        generate random negative samples
        remove random false-negative samples
        save the dataset to a json file
    Usage:
        python preprocess_v1.py --output_path /home/firouzi/embedding_model/data/train.json
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()
    output_path = args.output_path
    main(output_path)