diff --git a/src/configuration.py b/src/configuration.py
index 97cb2a1..5d69d91 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -52,6 +52,25 @@ class Configuration:
         response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)
         return response.json()
+
+
+    def get_bge_embedding(self, sentences, prompt_name):
+        embedding_url = "https://bge.chatllm.aiengines.ir/v1/embeddings"
+        headers = {"accept": "application/json"}
+        headers["Content-Type"] = "application/json"
+        headers["Authorization"] = f"Bearer {os.environ['EMBEDDING_PASS']}"
+
+
+        data = {}
+        data["model"] = "BAAI/bge-m3"
+        data["input"] = sentences
+        data["normalize"] = True
+
+        response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)
+        res = response.json()
+
+        final_res = [res["data"][i]["embedding"] for i in range(len(sentences))]
+        return final_res
 
 
 
     def embedding_persona(self):
@@ -64,7 +83,7 @@ class Configuration:
         for i in tqdm.trange(0, len(all_persona), batch_size):
             start_idx = i
             stop_idx = min(len(all_persona), start_idx+batch_size)
-            all_embeddings += self.get_embedding(all_persona[start_idx:stop_idx], prompt_name)
+            all_embeddings += self.get_bge_embedding(all_persona[start_idx:stop_idx], prompt_name)
 
         xb = numpy.array(all_embeddings).astype('float32')
         index = faiss.IndexFlatL2(len(all_embeddings[0]))
@@ -123,7 +142,7 @@ Ensure to generate only the JSON output with content in English.
 
 
     def get_persona(self, passage):
-        query_embedding = self.get_embedding(passage, "query")
+        query_embedding = self.get_bge_embedding([passage], "query")
 
         query = numpy.array(query_embedding, dtype='float32')
 
diff --git a/src/pipline.py.py b/src/pipline.py.py
index efe65af..3d25963 100644
--- a/src/pipline.py.py
+++ b/src/pipline.py.py
@@ -32,15 +32,39 @@ class Pipline:
         rows = df.values.tolist()
         rows = [rows[i][0] for i in range(len(rows))]
         return rows
+
 
-    def save_dataset(self, data):
+    def get_new_path(self):
         path = self.file_path + "/../data/generated"
         if not os.path.exists(path):
             os.makedirs(path)
 
-        files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]
+        folders = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]
 
-        pattern = r"^v(\d+)_dataset\.json$"
+        pattern = r"^v(\d+)$"
+
+        all_numbers = []
+
+        for f in folders:
+            match = re.match(pattern, f)
+            if match:
+                num = int(match.group(1))
+                all_numbers.append(num)
+
+        if all_numbers:
+            number = max(all_numbers) + 1
+        else:
+            number = 1
+
+        path = os.path.join(path, "v" + str(number))
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def get_json_path(self, save_path):
+        files = [f for f in os.listdir(save_path) if os.path.isfile(os.path.join(save_path, f))]
+
+        pattern = r"^part_(\d+)_dataset\.json$"
 
         all_numbers = []
 
@@ -54,8 +78,17 @@ class Pipline:
             number = max(all_numbers) + 1
         else:
             number = 1
+
+        json_path = os.path.join(save_path, "part_" + str(number) + "_dataset.json")
+        return json_path
 
-        with open(path + "/v" + str(number) + "_dataset.json", "w", encoding="utf-8") as f:
+
+    def save_dataset(self, data, save_path):
+
+
+        json_path = self.get_json_path(save_path)
+
+        with open(json_path, "w", encoding="utf-8") as f:
             json.dump(data, f, ensure_ascii=False, indent=2)
 
 
@@ -88,7 +121,10 @@ class Pipline:
         for i in range(start_idx, len(sentences)):
             if len(one_passage) + len(sentences[i]) > selected_lenth and len(one_passage) > 0:
                 return one_passage, i
-            one_passage += sentences[i]
+            if one_passage == "":
+                one_passage += sentences[i]
+            else:
+                one_passage += "." + sentences[i]
 
         return one_passage, len(sentences)
 
@@ -127,17 +163,28 @@ class Pipline:
         return chunk_data
 
 
-    def run(self):
+    def run_one_part(self, chunk_data, save_path, num_threads):
+        parallel_requester = ParallelRequester()
+        dataset = parallel_requester.run(chunk_data, self.exec_function, num_threads)
+
+        self.save_dataset(dataset, save_path)
+
+
+    def run(self, save_path = None):
         data = self.load_data()
         chunk_data = self.pre_process(data)
 
-        num_data = 25000
+        num_data = 250000
+        num_part_data = 25000
         num_threads = 5
 
-        parallel_requester = ParallelRequester()
-        dataset = parallel_requester.run(chunk_data[0:num_data], self.exec_function, num_threads)
-
-        self.save_dataset(dataset)
+        if save_path == None:
+            save_path = self.get_new_path()
+
+        for i in range(0, num_data, num_part_data):
+            start_idx = i
+            stop_idx = min(i+num_part_data, num_data)
+            self.run_one_part(chunk_data[start_idx:stop_idx], save_path, num_threads)
 
 
 def main():