# embedding_model/data_preprocess/generate_random_negative_sample.py
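"""Mine random negative passages for embedding-model training data.

For every question, embed the question and its positive and negative
passages in batches, then repeatedly sample passages from other rows and
keep those whose similarity to the question falls below THRESHOLD_MULTIPLY
times the question's score against its first positive passage.
"""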

import random

import numpy as np
from tqdm import tqdm

from data_preprocess.text_embedder import TextEmbedder

# A sampled passage counts as a negative only if its similarity to the
# question is below this fraction of the question's positive-passage score.
THRESHOLD_MULTIPLY = 0.9
# Number of random negatives to mine per question.
RANDOM_NEGATIVE_COUNT = 25
# Number of texts embedded per call to TextEmbedder.embed_texts.
BATCH_SIZE = 100

text_embedder = TextEmbedder()
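

# The three embedding passes in generate_random_negative_sample share one
# batching pattern: slice the texts into BATCH_SIZE chunks, embed each chunk
# with a single embed_texts call, then fan the vectors back out to their
# rows. A minimal sketch of that pattern, assuming only that embed_texts
# maps a list of strings to a list of vectors; the function below inlines
# the pattern rather than calling this helper.
def _embed_in_batches(texts, batch_size=BATCH_SIZE):
    embeddings = []
    for start in range(0, len(texts), batch_size):
        embeddings.extend(text_embedder.embed_texts(texts[start:start + batch_size]))
    return embeddings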


def generate_random_negative_sample(all_dataset):
    """Add mined random negative passages to every record in the dataset.

    Args:
        all_dataset: list of dicts, each with 'question', 'passage_positive'
            and 'passage_negative' fields.

    Returns:
        The same list of dicts, with a 'passage_negative_random' list of
        RANDOM_NEGATIVE_COUNT passages added to each record.
    """
    len_dataset = len(all_dataset)
    all_dataset_embeddings = [
        {
            'question_embedding': None,
            'passage_positive_embedding': [],
            'passage_negative_embeddings': [],
        }
        for _ in range(len_dataset)
    ]

    print("calculate question embeddings")
    for i in tqdm(range(0, len_dataset, BATCH_SIZE)):
        question_list = [
            all_dataset[idx]['question']
            for idx in range(i, min(i + BATCH_SIZE, len_dataset))
        ]
        question_embeddings = text_embedder.embed_texts(question_list)
        # Fan the batch of vectors back out to their rows.
        for count, idx in enumerate(range(i, min(i + BATCH_SIZE, len_dataset))):
            all_dataset_embeddings[idx]['question_embedding'] = question_embeddings[count]
print("calculate passage positive embeddings")
# calculate passage positive embeddings
for i in tqdm(range(0, len_dataset, batch_size)):
passage_positive_list = []
for id in range(i, min(i + batch_size, len_dataset)):
for passage in all_dataset[id]['passage_positive']:
passage_positive_list.append(passage)
passage_positive_embeddings = text_embedder.embed_texts(passage_positive_list)
count = 0
for id in range(i, min(i + batch_size, len_dataset)):
for passage_id in range(len(all_dataset[id]['passage_positive'])):
all_dataset_embeddings[id]['passage_positive_embedding'].append(passage_positive_embeddings[count])
count += 1
print("calculate passage negative embeddings")
# calculate passage negative embeddings
for i in tqdm(range(0, len_dataset, batch_size)):
passage_negative_list = []
for id in range(i, min(i + batch_size, len_dataset)):
for passage in all_dataset[id]['passage_negative']:
passage_negative_list.append(passage)
passage_negative_embeddings = text_embedder.embed_texts(passage_negative_list)
count = 0
for id in range(i, min(i + batch_size, len_dataset)):
for passage_id in range(len(all_dataset[id]['passage_negative'])):
all_dataset_embeddings[id]['passage_negative_embeddings'].append(passage_negative_embeddings[count])
count += 1
print("getting random negative passages")
for id in tqdm(range(len_dataset)):
question_embeddings = all_dataset_embeddings[id]['question_embedding']
passage_positive_embeddings = all_dataset_embeddings[id]['passage_positive_embedding'][0]
score_question_passage_positive = np.dot(question_embeddings, passage_positive_embeddings)
while len(all_dataset[id]['passage_negative_random']) < RANDOM_NEGATIVE_COUNT:
random_id = random.randint(0, len_dataset - 1)
if random_id != id:
all_passages_embedding = all_dataset_embeddings[random_id]['passage_negative_embeddings'] + all_dataset_embeddings[random_id]['passage_positive_embedding']
all_passages = all_dataset[random_id]['passage_negative'] + all_dataset[random_id]['passage_positive']
random_passage_id = random.randint(0, len(all_passages) - 1)
random_passage_embeddings = all_passages_embedding[random_passage_id]
score_question_random_passage = np.dot(question_embeddings, random_passage_embeddings)
if score_question_random_passage < THRESHOLD_MULTIPLY * score_question_passage_positive:
all_dataset[id]['passage_negative_random'].append(all_passages[random_passage_id])
return all_dataset
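

if __name__ == "__main__":
    # Hedged usage sketch. The record schema is inferred from the field
    # accesses above; the rows here are hypothetical, and a real run needs a
    # large, varied dataset so the sampler can actually find
    # RANDOM_NEGATIVE_COUNT passages per question that clear the threshold.
    toy_dataset = [
        {
            "question": "What is the capital of France?",
            "passage_positive": ["Paris is the capital of France."],
            "passage_negative": ["Lyon is the third-largest city in France."],
        },
        {
            "question": "Who wrote Hamlet?",
            "passage_positive": ["Hamlet is a tragedy by William Shakespeare."],
            "passage_negative": ["Macbeth is another Shakespeare tragedy."],
        },
    ]
    enriched = generate_random_negative_sample(toy_dataset)
    print(enriched[0]["passage_negative_random"][:3])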