embedding_model/data_preprocess/preprocess_v2.py
2025-11-16 15:30:36 +00:00

160 lines
5.1 KiB
Python

import argparse
from datasets import load_dataset
import json
from tqdm import tqdm
from data_preprocess.remove_false_negative_model import LLMModel
from data_preprocess.generate_random_negative_sample import generate_random_negative_sample
# Module-level LLM judge, instantiated once at import time; used by
# remove_false_negative() to score (question, negative passage) pairs.
llm_model = LLMModel()
def load_msmarco_dataset():
    """
    Load the Persian MS MARCO dataset (MCINext/msmarco-fa) from HuggingFace.

    Joins the qrel split against the corpus and queries files to build one
    training record per (query, positive passage) pair.

    Returns:
        tuple:
            - dataset (list[dict]): records of the form
              {"question": str,
               "passage_positive": [str],
               "passage_negative": [],
               "passage_negative_random": []}
            - corpus (list[str]): all corpus passage texts, for random
              negative sampling downstream.
    """
    print("start loading msmarco dataset")
    name = "MCINext/msmarco-fa"
    dataset_qrel = load_dataset(name)["train"]

    print("start loading corpus")
    dataset_corpus_list = load_dataset(name, data_files="corpus.jsonl")["train"]
    # Keys normalized to str so qrel ids (which may load as ints) match.
    dataset_corpus = {str(row["_id"]): row["text"] for row in dataset_corpus_list}

    print("start loading queries")
    dataset_queries_list = load_dataset(name, data_files="queries.jsonl")["train"]
    dataset_queries = {str(row["_id"]): row["text"] for row in dataset_queries_list}

    dataset = []
    print("start creating dataset")
    for data in tqdm(dataset_qrel):
        # Normalize qrel ids to str to match the lookup tables above; the
        # original compared raw ids against str keys, which silently drops
        # every pair whenever the ids deserialize as ints — TODO confirm
        # id dtypes against the actual dataset.
        query_id = str(data["query-id"])
        corpus_id = str(data["corpus-id"])
        if query_id in dataset_queries and corpus_id in dataset_corpus:
            dataset.append({
                "question": dataset_queries[query_id],
                "passage_positive": [dataset_corpus[corpus_id]],
                "passage_negative": [],
                "passage_negative_random": [],
            })
    print(f"length of dataset: {len(dataset)}")
    print("--------------------------------")
    return dataset, list(dataset_corpus.values())
def remove_false_negative(dataset, random_negative_sample=False):
    """
    Remove false-negative passages from a synthetic dataset using the
    module-level LLM judge (llm_model).

    For each (question, negative passage) pair the judge returns "1" when
    the passage actually answers the question (a false negative) and "0"
    otherwise; passages judged "1" are dropped in place.

    Args:
        dataset: list of dicts with a "question" key and negative-passage
            list fields.
        random_negative_sample: if True, filter "passage_negative_random";
            otherwise filter "passage_negative".

    Returns:
        The same dataset list, with false negatives removed from the
        selected negative field.
    """
    negative_name = (
        "passage_negative_random" if random_negative_sample else "passage_negative"
    )
    negative_count_all = 0
    negative_count_removed = 0
    len_dataset = len(dataset)
    batch_size = 50
    for start in tqdm(range(0, len_dataset, batch_size)):
        end = min(start + batch_size, len_dataset)
        # Flatten this batch into parallel question/passage lists so the
        # judge can be called once per batch instead of once per pair.
        question_list = []
        passage_negative_list = []
        for idx in range(start, end):  # `idx`, not shadowing builtin `id`
            for passage in dataset[idx][negative_name]:
                question_list.append(dataset[idx]['question'])
                passage_negative_list.append(passage)
        results = llm_model.remove_false_negative_llm(question_list, passage_negative_list)
        negative_count_removed += sum(1 for verdict in results if verdict == "1")
        negative_count_all += len(results)
        # Walk results in the same order the pairs were submitted and keep
        # only passages the judge marked "0" (true negatives).
        count = 0
        for idx in range(start, end):
            kept = []
            for passage in dataset[idx][negative_name]:
                if results[count] == "0":
                    kept.append(passage)
                count += 1
            dataset[idx][negative_name] = kept
    print(f"removed {negative_count_removed} false negative samples from {negative_count_all} samples")
    print("--------------------------------")
    return dataset
def save_dataset(dataset, output_path):
    """
    Serialize the dataset to a JSON file (UTF-8, non-ASCII preserved,
    4-space indentation).

    Args:
        dataset: list of dicts to serialize.
        output_path: destination file path.
    """
    serialized = json.dumps(dataset, ensure_ascii=False, indent=4)
    with open(output_path, 'w', encoding='utf-8') as out_file:
        out_file.write(serialized)
def main(output_path):
    """
    Run the full preprocessing pipeline and write the result to disk.

    Steps: load the msmarco dataset, attach random negative samples,
    drop false negatives via the LLM judge, then save to JSON.

    Args:
        output_path: path of the output JSON file.
    """
    sep = "--------------------------------"

    # Load msmarco dataset.
    print(sep)
    all_dataset, corpus_list = load_msmarco_dataset()
    print(f"msmarco dataset loaded : {len(all_dataset)} samples")
    print(sep)

    # Generate random negative samples.
    print("start to generate random negative samples")
    all_dataset = generate_random_negative_sample(all_dataset, corpus_list)
    print("successfully generated random negative samples")
    print(sep)

    # Remove random false negatives from the whole dataset.
    print("start to remove random false negative samples from all dataset")
    all_dataset = remove_false_negative(all_dataset, random_negative_sample=True)
    print("successfully removed random false negative samples from all dataset")
    print(sep)

    # Persist the result.
    print("start to save dataset")
    save_dataset(all_dataset, output_path)
    print("successfully saved dataset")
    print(sep)
if __name__ == "__main__":
    # Preprocess dataset for training.
    # Pipeline:
    #   - load msmarco dataset from HuggingFace
    #   - generate random negative samples
    #   - save dataset to a JSON file
    # Usage:
    #   python preprocess_v2.py --output_path /home/firouzi/embedding_model/data/v2/msmarco_train.json
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--output_path", type=str, required=True)
    cli_args = arg_parser.parse_args()
    output_path = cli_args.output_path
    main(output_path)