# Generates random negative passage samples for retrieval training data.
import random
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
|
|
from data_preprocess.text_embedder import TextEmbedder
|
|
|
|
|
|
# A random passage is accepted as a negative only if its similarity to the
# question is below this fraction of the question/first-positive similarity.
THRESHOLD_MULTIPLY = 0.9

# Target number of random negative passages collected per dataset entry.
RANDOM_NEGATIVE_COUNT = 25

# Number of dataset entries whose texts are embedded per embedder call.
batch_size = 100

# Shared embedder instance used by every embedding pass in this module.
text_embedder = TextEmbedder()
|
|
|
|
|
|
def generate_random_negative_sample(all_dataset):
    """Populate each entry's ``passage_negative_random`` with passages
    sampled from *other* entries.

    A sampled passage is kept only when its dot-product similarity to the
    entry's question is below ``THRESHOLD_MULTIPLY`` times the similarity
    between the question and the entry's first positive passage.

    Args:
        all_dataset: list of dicts, each with keys ``'question'`` (str),
            ``'passage_positive'`` (non-empty list of str), and
            ``'passage_negative'`` (list of str). The key
            ``'passage_negative_random'`` is created if missing and is
            filled in place with up to RANDOM_NEGATIVE_COUNT passages.

    Returns:
        The same ``all_dataset`` list, mutated in place.
    """
    len_dataset = len(all_dataset)
    # Negatives come from *other* entries, so with fewer than two entries
    # there is nothing to sample (and the old loop would spin forever).
    if len_dataset < 2:
        return all_dataset

    all_dataset_embeddings = [
        {'question_embedding': None,
         'passage_positive_embedding': [],
         'passage_negative_embeddings': []}
        for _ in range(len_dataset)
    ]

    print("calculate question embeddings")
    # One question per entry, embedded in batches of `batch_size` entries.
    for start in tqdm(range(0, len_dataset, batch_size)):
        batch = all_dataset[start:start + batch_size]
        embeddings = text_embedder.embed_texts([e['question'] for e in batch])
        for offset, emb in enumerate(embeddings):
            all_dataset_embeddings[start + offset]['question_embedding'] = emb

    print("calculate passage positive embeddings")
    _embed_passages(all_dataset, all_dataset_embeddings,
                    'passage_positive', 'passage_positive_embedding')

    print("calculate passage negative embeddings")
    _embed_passages(all_dataset, all_dataset_embeddings,
                    'passage_negative', 'passage_negative_embeddings')

    print("getting random negative passages")
    for idx in tqdm(range(len_dataset)):
        question_embedding = all_dataset_embeddings[idx]['question_embedding']
        # Similarity baseline uses only the FIRST positive passage.
        positive_embedding = all_dataset_embeddings[idx]['passage_positive_embedding'][0]
        positive_score = np.dot(question_embedding, positive_embedding)

        negatives = all_dataset[idx].setdefault('passage_negative_random', [])
        # Cap attempts so a pathological entry (question similar to every
        # other passage) cannot hang the whole preprocessing run.
        max_attempts = RANDOM_NEGATIVE_COUNT * 1000
        attempts = 0
        while len(negatives) < RANDOM_NEGATIVE_COUNT and attempts < max_attempts:
            attempts += 1
            other = random.randint(0, len_dataset - 1)
            if other == idx:
                continue
            candidate_embeddings = (
                all_dataset_embeddings[other]['passage_negative_embeddings']
                + all_dataset_embeddings[other]['passage_positive_embedding'])
            candidate_passages = (all_dataset[other]['passage_negative']
                                  + all_dataset[other]['passage_positive'])
            pick = random.randint(0, len(candidate_passages) - 1)
            score = np.dot(question_embedding, candidate_embeddings[pick])
            if score < THRESHOLD_MULTIPLY * positive_score:
                negatives.append(candidate_passages[pick])

    return all_dataset


def _embed_passages(all_dataset, all_dataset_embeddings, src_key, dst_key):
    """Embed all passages under ``src_key`` in entry-batches and append
    each vector to ``dst_key`` of the matching embedding record.

    Batches are formed over entries (not passages), mirroring the
    question-embedding pass; the flat embedding list is redistributed by
    walking the same per-entry passage counts.
    """
    len_dataset = len(all_dataset)
    for start in tqdm(range(0, len_dataset, batch_size)):
        end = min(start + batch_size, len_dataset)
        flat_passages = [p for i in range(start, end)
                         for p in all_dataset[i][src_key]]
        embeddings = text_embedder.embed_texts(flat_passages)
        cursor = 0
        for i in range(start, end):
            for _ in all_dataset[i][src_key]:
                all_dataset_embeddings[i][dst_key].append(embeddings[cursor])
                cursor += 1