embedding_model/data_preprocess/generate_random_negative_sample.py

from tqdm import tqdm
import numpy as np
import faiss

from data_preprocess.text_embedder import TextEmbedder

# A candidate negative must score below this fraction of the
# question/positive similarity to be accepted.
THRESHOLD_MULTIPLY = 0.95
# Number of random hard negatives to collect per example.
RANDOM_NEGATIVE_COUNT = 6
# Number of texts embedded per call to the embedder.
BATCH_SIZE = 100

text_embedder = TextEmbedder()


def generate_random_negative_sample(all_dataset):
    """
    Generate random hard-negative samples for every entry in the dataset.

    Args:
        all_dataset: list of dicts, each with 'question', 'passage_positive'
            and 'passage_negative' keys.
    Returns:
        The same list, with 'passage_negative_random' filled in per entry.
    """
    len_dataset = len(all_dataset)
    # Per-example embeddings, kept so question/positive pairs can be scored later.
    all_dataset_embeddings = [
        {'question_embedding': None, 'passage_positive_embedding': []}
        for _ in range(len_dataset)
    ]
    # Pool of candidate passages: all positives and negatives across the dataset.
    all_embeddings = []
    all_texts = []
print("calculate question embeddings")
# calculate question embeddings
for i in tqdm(range(0, len_dataset, batch_size)):
question_list = []
for id in range(i, min(i + batch_size, len_dataset)):
question_list.append(all_dataset[id]['question'])
question_embeddings = text_embedder.embed_texts(question_list)
count = 0
for id in range(i, min(i + batch_size, len_dataset)):
all_dataset_embeddings[id]['question_embedding'] = question_embeddings[count]
count += 1
print("calculate passage positive embeddings")
# calculate passage positive embeddings
for i in tqdm(range(0, len_dataset, batch_size)):
passage_positive_list = []
for id in range(i, min(i + batch_size, len_dataset)):
for passage in all_dataset[id]['passage_positive']:
passage_positive_list.append(passage)
passage_positive_embeddings = text_embedder.embed_texts(passage_positive_list)
count = 0
for id in range(i, min(i + batch_size, len_dataset)):
for passage_id in range(len(all_dataset[id]['passage_positive'])):
all_dataset_embeddings[id]['passage_positive_embedding'].append(passage_positive_embeddings[count])
all_embeddings.append(passage_positive_embeddings[count])
all_texts.append(all_dataset[id]['passage_positive'][passage_id])
count += 1
print("calculate passage negative embeddings")
# calculate passage negative embeddings
for i in tqdm(range(0, len_dataset, batch_size)):
passage_negative_list = []
for id in range(i, min(i + batch_size, len_dataset)):
for passage in all_dataset[id]['passage_negative']:
passage_negative_list.append(passage)
passage_negative_embeddings = text_embedder.embed_texts(passage_negative_list)
count = 0
for id in range(i, min(i + batch_size, len_dataset)):
for passage_id in range(len(all_dataset[id]['passage_negative'])):
all_embeddings.append(passage_negative_embeddings[count])
all_texts.append(all_dataset[id]['passage_negative'][passage_id])
count += 1

    ############ Create FAISS index ############
    all_embeddings = np.array(all_embeddings, dtype=np.float32)
    dim = all_embeddings.shape[1]
    # Inner product over L2-normalized vectors equals cosine similarity.
    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(all_embeddings)
    index.add(all_embeddings)

    ############ Get random hard negative passages ############
    print("getting random negative passages")
    num_retrieved = 30
    for idx in tqdm(range(len_dataset)):
        # Make sure the output field exists even if it was not pre-populated upstream.
        all_dataset[idx].setdefault('passage_negative_random', [])
        # Passages already labelled for this example are not valid random negatives.
        not_valid_passages = set(all_dataset[idx]['passage_negative'] + all_dataset[idx]['passage_positive'])
        question_embedding = np.array([all_dataset_embeddings[idx]['question_embedding']], dtype=np.float32)
        faiss.normalize_L2(question_embedding)
        # Normalize the first positive passage as well, so its cosine score is
        # comparable with the inner-product scores returned by the normalized index.
        passage_positive_embedding = np.array(
            [all_dataset_embeddings[idx]['passage_positive_embedding'][0]], dtype=np.float32
        )
        faiss.normalize_L2(passage_positive_embedding)
        score_question_passage_positive = float(np.dot(question_embedding[0], passage_positive_embedding[0]))
        vector_scores, vector_ids = index.search(question_embedding, num_retrieved)
        for vector_score, vector_id in zip(vector_scores[0], vector_ids[0]):
            # Accept candidates retrieved for the question that are clearly less
            # similar than the true positive: hard, but unlikely false, negatives.
            if (all_texts[vector_id] not in not_valid_passages
                    and vector_score < THRESHOLD_MULTIPLY * score_question_passage_positive):
                all_dataset[idx]['passage_negative_random'].append(all_texts[vector_id])
            if len(all_dataset[idx]['passage_negative_random']) >= RANDOM_NEGATIVE_COUNT:
                break
    return all_dataset
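

# Minimal, self-contained usage sketch (not part of the original module): the
# toy dataset below is hypothetical and only illustrates the expected input
# shape; it assumes TextEmbedder is importable and returns per-text vectors.
if __name__ == "__main__":
    toy_dataset = [
        {
            'question': "What is FAISS used for?",
            'passage_positive': ["FAISS is a library for efficient similarity search."],
            'passage_negative': ["The Eiffel Tower is located in Paris."],
        },
        {
            'question': "Who wrote Hamlet?",
            'passage_positive': ["Hamlet is a tragedy written by William Shakespeare."],
            'passage_negative': ["FAISS indexes dense vectors."],
        },
    ]
    enriched = generate_random_negative_sample(toy_dataset)
    for entry in enriched:
        print(entry['question'], '->', entry['passage_negative_random'])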