"""Augment a QA dataset with random negative passages drawn from other examples.

Every question and passage is embedded with TextEmbedder; a passage from a
different example qualifies as a "random negative" for a question when its
dot-product score against the question is sufficiently below the question's
score against its own first positive passage.
"""

import random

import numpy as np
from tqdm import tqdm

from data_preprocess.text_embedder import TextEmbedder

# A candidate only counts as a negative if its question score is below this
# fraction of the question's score against its own first positive passage.
THRESHOLD_MULTIPLY = 0.9
# Number of random negatives to collect for each example.
RANDOM_NEGATIVE_COUNT = 25
# Number of examples whose texts are embedded per call to the embedder.
batch_size = 100

text_embedder = TextEmbedder()


def _embed_passage_field(all_dataset, field, desc):
    """Embed the list-valued passage field of every example, in batches.

    Args:
        all_dataset: list of dicts; each dict holds a list of strings under
            ``field``.
        field: key whose list of passages is embedded ('passage_positive'
            or 'passage_negative').
        desc: progress message printed before the batched embedding loop.

    Returns:
        A list parallel to ``all_dataset``; element i is the list of
        embeddings for ``all_dataset[i][field]``, in the original order.
    """
    n = len(all_dataset)
    per_example = [[] for _ in range(n)]
    print(desc)
    for start in tqdm(range(0, n, batch_size)):
        end = min(start + batch_size, n)
        # Flatten the batch so one embed_texts call covers several examples.
        texts = [p for i in range(start, end) for p in all_dataset[i][field]]
        embeddings = text_embedder.embed_texts(texts)
        pos = 0
        for i in range(start, end):
            for _ in all_dataset[i][field]:
                per_example[i].append(embeddings[pos])
                pos += 1
    return per_example


def generate_random_negative_sample(all_dataset):
    """Fill each example's 'passage_negative_random' with sampled negatives.

    Args:
        all_dataset: list of dicts, each with keys 'question' (str),
            'passage_positive' (non-empty list of str — the first positive is
            used as the similarity reference) and 'passage_negative'
            (list of str). A 'passage_negative_random' list is created if
            missing; existing entries are kept and topped up.

    Returns:
        all_dataset, mutated in place: every example's
        'passage_negative_random' holds RANDOM_NEGATIVE_COUNT passages taken
        from *other* examples whose score against the question is below
        THRESHOLD_MULTIPLY times the question/first-positive score.

    Raises:
        ValueError: if the dataset has fewer than two examples — negatives
            must come from a different example, so sampling could never
            terminate (the original code would loop forever here).
    """
    len_dataset = len(all_dataset)
    if len_dataset < 2:
        raise ValueError(
            "generate_random_negative_sample needs at least 2 examples; "
            f"got {len_dataset}"
        )

    # --- embed all questions, batched ---------------------------------------
    print("calculate question embeddings")
    question_embeddings = []
    for start in tqdm(range(0, len_dataset, batch_size)):
        end = min(start + batch_size, len_dataset)
        batch = [all_dataset[i]['question'] for i in range(start, end)]
        question_embeddings.extend(text_embedder.embed_texts(batch))

    # --- embed all positive / negative passages, batched --------------------
    positive_embeddings = _embed_passage_field(
        all_dataset, 'passage_positive', "calculate passage positive embeddings"
    )
    negative_embeddings = _embed_passage_field(
        all_dataset, 'passage_negative', "calculate passage negative embeddings"
    )

    # --- sample random negatives from other examples ------------------------
    print("getting random negative passages")
    for i in tqdm(range(len_dataset)):
        q_emb = question_embeddings[i]
        # Reference score: question vs. its own FIRST positive passage.
        positive_score = np.dot(q_emb, positive_embeddings[i][0])
        negatives = all_dataset[i].setdefault('passage_negative_random', [])
        while len(negatives) < RANDOM_NEGATIVE_COUNT:
            other = random.randrange(len_dataset)
            if other == i:
                continue
            # Candidate pool: the other example's negatives then positives,
            # with embeddings kept in the same concatenation order.
            candidate_passages = (
                all_dataset[other]['passage_negative']
                + all_dataset[other]['passage_positive']
            )
            if not candidate_passages:
                # Original randint(0, -1) would raise ValueError here.
                continue
            candidate_embeddings = (
                negative_embeddings[other] + positive_embeddings[other]
            )
            k = random.randrange(len(candidate_passages))
            score = np.dot(q_emb, candidate_embeddings[k])
            if score < THRESHOLD_MULTIPLY * positive_score:
                negatives.append(candidate_passages[k])
    return all_dataset