import random

import numpy as np
import faiss
from tqdm import tqdm

from data_preprocess.text_embedder import TextEmbedder

# Similarity cutoff as a fraction of the question/positive score.
# NOTE(review): the filter that used this was commented out in the original
# search loop; the constant is kept for when that filter is re-enabled.
THRESHOLD_MULTIPLY = 0.95
# Hard negatives to mine per question.
RANDOM_NEGATIVE_COUNT = 6
# Texts embedded per call to the embedder.
batch_size = 1000

text_embedder = TextEmbedder()


def generate_random_negative_sample(all_dataset, corpus_list=None):
    """Mine hard negative passages for every question in ``all_dataset``.

    Embeds all questions, positive passages, negative passages and an optional
    extra corpus, builds a FAISS HNSW inner-product index over the L2-normalized
    passage embeddings, then for each question retrieves its nearest passages
    and appends up to ``RANDOM_NEGATIVE_COUNT`` of them — excluding the
    question's own positives and negatives — to that entry's
    ``'passage_negative_random'`` list.

    Args:
        all_dataset: list of dicts, each with keys ``'question'`` (str),
            ``'passage_positive'`` (list[str]), ``'passage_negative'``
            (list[str]) and ``'passage_negative_random'`` (list[str],
            appended to in place; created if missing).
        corpus_list: optional extra passages to add to the search pool.
            Defaults to no extra passages.

    Returns:
        The same ``all_dataset`` list, mutated in place.
    """
    # Avoid the shared-mutable-default pitfall (`corpus_list=[]`).
    if corpus_list is None:
        corpus_list = []

    len_dataset = len(all_dataset)
    all_dataset_embeddings = [
        {'question_embedding': "", 'passage_positive_embedding': []}
        for _ in range(len_dataset)
    ]
    all_embeddings = []  # every candidate-passage embedding, index-aligned with all_texts
    all_texts = []       # candidate-passage text for each row of the FAISS index

    print("calculate question embeddings")
    # Embed questions batch-by-batch and attach each embedding to its entry.
    for start in tqdm(range(0, len_dataset, batch_size)):
        end = min(start + batch_size, len_dataset)
        question_list = [all_dataset[idx]['question'] for idx in range(start, end)]
        question_embeddings = text_embedder.embed_texts(
            question_list, do_preprocess=False, convert_to_numpy=False)
        for offset, idx in enumerate(range(start, end)):
            all_dataset_embeddings[idx]['question_embedding'] = question_embeddings[offset]

    print("calculate passage positive embeddings")
    # Positives go both into the per-entry record and the global search pool.
    for start in tqdm(range(0, len_dataset, batch_size)):
        end = min(start + batch_size, len_dataset)
        passage_positive_list = [
            passage
            for idx in range(start, end)
            for passage in all_dataset[idx]['passage_positive']
        ]
        passage_positive_embeddings = text_embedder.embed_texts(
            passage_positive_list, do_preprocess=False, convert_to_numpy=False)
        count = 0
        for idx in range(start, end):
            for passage in all_dataset[idx]['passage_positive']:
                all_dataset_embeddings[idx]['passage_positive_embedding'].append(
                    passage_positive_embeddings[count])
                all_embeddings.append(passage_positive_embeddings[count])
                all_texts.append(passage)
                count += 1

    print("calculate passage negative embeddings")
    # Existing negatives only enlarge the global search pool.
    for start in tqdm(range(0, len_dataset, batch_size)):
        end = min(start + batch_size, len_dataset)
        passage_negative_list = [
            passage
            for idx in range(start, end)
            for passage in all_dataset[idx]['passage_negative']
        ]
        passage_negative_embeddings = text_embedder.embed_texts(
            passage_negative_list, do_preprocess=False, convert_to_numpy=False)
        count = 0
        for idx in range(start, end):
            for passage in all_dataset[idx]['passage_negative']:
                all_embeddings.append(passage_negative_embeddings[count])
                all_texts.append(passage)
                count += 1

    print("calculate corpus embeddings")
    for start in tqdm(range(0, len(corpus_list), batch_size)):
        chunk = corpus_list[start:start + batch_size]
        corpus_embeddings = text_embedder.embed_texts(
            chunk, do_preprocess=False, convert_to_numpy=False)
        all_embeddings.extend(corpus_embeddings)
        all_texts.extend(chunk)

    # Guard: with no passages at all there is nothing to index or mine
    # (the original crashed on `.shape[1]` of an empty array here).
    if not all_embeddings:
        return all_dataset

    ############ Create FAISS index ############
    all_embeddings = np.array(all_embeddings, dtype=np.float32)
    dim = all_embeddings.shape[1]
    # HNSW + inner product over L2-normalized vectors = approximate cosine search.
    index = faiss.IndexHNSWFlat(dim, 32, faiss.METRIC_INNER_PRODUCT)
    faiss.normalize_L2(all_embeddings)
    index.add(all_embeddings)

    ############ Get random hard negative passages ############
    print("getting random negative passages")
    num_retrieved = 15  # neighbors fetched per query; some get filtered out below
    for idx in tqdm(range(len_dataset)):
        entry = all_dataset[idx]
        # Passages the question already uses must not be re-mined as negatives.
        not_valid_passages = entry['passage_negative'] + entry['passage_positive']

        query = np.array([all_dataset_embeddings[idx]['question_embedding']],
                         dtype=np.float32)
        faiss.normalize_L2(query)
        _vector_scores, vector_ids = index.search(query, num_retrieved)

        # setdefault: robust when an entry lacks the key (original assumed it).
        negatives = entry.setdefault('passage_negative_random', [])
        for vector_id in vector_ids[0]:
            candidate = all_texts[vector_id]
            if candidate not in not_valid_passages:
                negatives.append(candidate)
                not_valid_passages.append(candidate)  # also deduplicates picks
                if len(negatives) >= RANDOM_NEGATIVE_COUNT:
                    break
    return all_dataset