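"""
Hard negative mining for retrieval training data.

Embeds questions and passages with TextEmbedder, builds a FAISS HNSW
index over all candidate passages (positives, known negatives, and an
optional extra corpus), then stores the passages most similar to each
question, minus its positives and known negatives, under
'passage_negative_random'.
"""
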
from tqdm import tqdm
import numpy as np
import faiss

from data_preprocess.text_embedder import TextEmbedder

# cap on a candidate's similarity relative to the true positive's score
# (used only by the commented-out filter in the mining loop below)
THRESHOLD_MULTIPLY = 0.95
# number of mined hard negatives to keep per question
RANDOM_NEGATIVE_COUNT = 6
# number of texts embedded per embedder call
BATCH_SIZE = 1000

text_embedder = TextEmbedder()


def generate_random_negative_sample(all_dataset, corpus_list=None):
    """
    Mine hard negative passages for every question in the dataset.

    Questions and passages are embedded, all candidate passages are
    indexed with FAISS, and the passages most similar to each question
    that are neither positives nor known negatives are stored as
    'passage_negative_random'.

    Args:
        all_dataset: list of dicts with keys 'question',
            'passage_positive', 'passage_negative', and
            'passage_negative_random'.
        corpus_list: optional list of extra passages to enlarge the
            candidate pool.

    Returns:
        all_dataset: the same list, with 'passage_negative_random'
            filled in for each entry.
    """
    # avoid the mutable-default-argument pitfall
    if corpus_list is None:
        corpus_list = []

    len_dataset = len(all_dataset)
    all_dataset_embeddings = [
        {'question_embedding': None, 'passage_positive_embedding': []}
        for _ in range(len_dataset)
    ]

    # flat pool of every candidate passage and its embedding,
    # used to build the FAISS index
    all_embeddings = []
    all_texts = []

print("calculate question embeddings")
|
|
# calculate question embeddings
|
|
for i in tqdm(range(0, len_dataset, batch_size)):
|
|
|
|
question_list = []
|
|
for id in range(i, min(i + batch_size, len_dataset)):
|
|
question_list.append(all_dataset[id]['question'])
|
|
|
|
question_embeddings = text_embedder.embed_texts(question_list, do_preprocess=False, convert_to_numpy=False)
|
|
|
|
count = 0
|
|
for id in range(i, min(i + batch_size, len_dataset)):
|
|
all_dataset_embeddings[id]['question_embedding'] = question_embeddings[count]
|
|
count += 1
|
|
|
|
|
|
print("calculate passage positive embeddings")
|
|
# calculate passage positive embeddings
|
|
for i in tqdm(range(0, len_dataset, batch_size)):
|
|
|
|
passage_positive_list = []
|
|
for id in range(i, min(i + batch_size, len_dataset)):
|
|
for passage in all_dataset[id]['passage_positive']:
|
|
passage_positive_list.append(passage)
|
|
|
|
passage_positive_embeddings = text_embedder.embed_texts(passage_positive_list, do_preprocess=False, convert_to_numpy=False)
|
|
|
|
count = 0
|
|
for id in range(i, min(i + batch_size, len_dataset)):
|
|
for passage_id in range(len(all_dataset[id]['passage_positive'])):
|
|
all_dataset_embeddings[id]['passage_positive_embedding'].append(passage_positive_embeddings[count])
|
|
all_embeddings.append(passage_positive_embeddings[count])
|
|
all_texts.append(all_dataset[id]['passage_positive'][passage_id])
|
|
count += 1
|
|
|
|
print("calculate passage negative embeddings")
|
|
# calculate passage negative embeddings
|
|
for i in tqdm(range(0, len_dataset, batch_size)):
|
|
|
|
passage_negative_list = []
|
|
for id in range(i, min(i + batch_size, len_dataset)):
|
|
for passage in all_dataset[id]['passage_negative']:
|
|
passage_negative_list.append(passage)
|
|
|
|
passage_negative_embeddings = text_embedder.embed_texts(passage_negative_list, do_preprocess=False, convert_to_numpy=False)
|
|
|
|
count = 0
|
|
for id in range(i, min(i + batch_size, len_dataset)):
|
|
for passage_id in range(len(all_dataset[id]['passage_negative'])):
|
|
all_embeddings.append(passage_negative_embeddings[count])
|
|
all_texts.append(all_dataset[id]['passage_negative'][passage_id])
|
|
count += 1
|
|
|
|
print("calculate corpus embeddings")
|
|
# calculate corpus embeddings
|
|
for i in tqdm(range(0, len(corpus_list), batch_size)):
|
|
corpus_embeddings = text_embedder.embed_texts(corpus_list[i:i+batch_size], do_preprocess=False, convert_to_numpy=False)
|
|
all_embeddings.extend(corpus_embeddings)
|
|
all_texts.extend(corpus_list[i:i+batch_size])
|
|
|
|
    ############ Create FAISS index ############
    all_embeddings = np.array(all_embeddings, dtype=np.float32)
    dim = all_embeddings.shape[1]
    # HNSW approximate index (32 neighbors per node); an exact
    # alternative would be:
    # index = faiss.IndexFlatIP(dim)
    index = faiss.IndexHNSWFlat(dim, 32, faiss.METRIC_INNER_PRODUCT)
    # L2-normalize so inner-product scores equal cosine similarity
    faiss.normalize_L2(all_embeddings)
    index.add(all_embeddings)

    ############ Get random hard negative passages ############
    print("getting random negative passages")
    for idx in tqdm(range(len_dataset)):
        # passages that must never be returned as new negatives
        not_valid_passages = all_dataset[idx]['passage_negative'] + all_dataset[idx]['passage_positive']

        question_embedding = all_dataset_embeddings[idx]['question_embedding']
        question_embedding_normalized = np.array([question_embedding], dtype=np.float32)
        faiss.normalize_L2(question_embedding_normalized)

        # optional similarity ceiling (currently disabled): drop candidates
        # scoring nearly as high as the true positive
        # passage_positive_embedding = all_dataset_embeddings[idx]['passage_positive_embedding'][0]
        # score_question_passage_positive = np.dot(question_embedding, passage_positive_embedding)

        # retrieve the nearest candidates and keep the first valid ones,
        # up to RANDOM_NEGATIVE_COUNT per question
        num_retrieved = 15
        vector_scores, vector_ids = index.search(question_embedding_normalized, num_retrieved)
        for vector_score, vector_id in zip(vector_scores[0], vector_ids[0]):
            if all_texts[vector_id] not in not_valid_passages:  # and (vector_score < THRESHOLD_MULTIPLY * score_question_passage_positive):
                all_dataset[idx]['passage_negative_random'].append(all_texts[vector_id])
                not_valid_passages.append(all_texts[vector_id])

            if len(all_dataset[idx]['passage_negative_random']) >= RANDOM_NEGATIVE_COUNT:
                break

    return all_dataset
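

# A minimal usage sketch under assumed inputs: the toy records below are
# hypothetical and only mirror the fields this module reads and writes.
# Note that 'passage_negative_random' must already exist on each record,
# since the mining loop appends to it.
if __name__ == "__main__":
    toy_dataset = [
        {
            'question': "What is FAISS?",
            'passage_positive': ["FAISS is a library for efficient similarity search."],
            'passage_negative': ["Pandas is a data analysis library."],
            'passage_negative_random': [],
        },
    ]
    extra_corpus = [
        "NumPy provides N-dimensional arrays.",
        "HNSW is a graph-based approximate nearest-neighbor index.",
        "tqdm renders progress bars for Python loops.",
    ]

    mined = generate_random_negative_sample(toy_dataset, extra_corpus)
    print(mined[0]['passage_negative_random'])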