diff --git a/src/configuration.py b/src/configuration.py index ab8477c..10506c5 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -134,11 +134,11 @@ Ensure to generate only the JSON output with content in English. return config_prompt - def init_persona(self): + def init_persona(self, worker_config): self.index = faiss.read_index(self.file_path + "/../data/faiss.index") self.all_persona = self.load_all_persona() - client = OpenAI(base_url=os.environ["OPENAI_BASE_URL"] ,api_key=os.environ["OPENAI_API_KEY"]) - self.openai_responder = OpenAIResponder(client=client, model=os.environ["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0) + client = OpenAI(base_url=worker_config["OPENAI_BASE_URL"] ,api_key=worker_config["OPENAI_API_KEY"]) + self.openai_responder = OpenAIResponder(client=client, model=worker_config["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0) def get_persona(self, passage): diff --git a/src/openai_responder.py b/src/openai_responder.py index abae37c..8bb5c7a 100644 --- a/src/openai_responder.py +++ b/src/openai_responder.py @@ -35,7 +35,7 @@ class OpenAIResponder: def get_body_to_request(self, messages, temperature): - body = {"model": self.model, "messages": messages,"max_tokens": 8000} + body = {"model": self.model, "messages": messages,"max_tokens": 1000} if temperature != None: body["temperature"] = temperature return body diff --git a/src/pipline.py b/src/pipline.py index 3d25963..c9d78ec 100644 --- a/src/pipline.py +++ b/src/pipline.py @@ -6,6 +6,9 @@ import random import tqdm import pandas as pd import traceback +import threading +from dotenv import load_dotenv + def import_lib(path, file_name, package_name): file_path = path + "/" + file_name + ".py" @@ -22,17 +25,47 @@ ParallelRequester = import_lib(os.path.dirname(__file__) , "parallel_requester", class Pipline: def __init__(self): self.file_path = os.path.dirname(__file__) - self.configuration = Configuration() - 
self.configuration.init_persona() - self.query_generator = QueryGenerator() + load_dotenv() + + worker_configs = self.load_worker_configs() + + self.lock = threading.Lock() + self.num_handling_request = [] + self.configuration = [] + self.query_generator = [] + for i in range(len(worker_configs)): + configuration = Configuration() + configuration.init_persona(worker_configs[i]) + self.configuration += [configuration] + self.query_generator += [QueryGenerator(worker_configs[i])] + self.num_handling_request += [0] + + + def load_worker_configs(self): + worker_configs = [] + for i in range(100): + try: + conf = {} + conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)] + conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)] + conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)] + worker_configs += [conf] + except KeyError: + continue + return worker_configs - def load_data(self): + def load_blogs_data(self): df = pd.read_csv(self.file_path + "/../data/persian_blog/blogs.csv") rows = df.values.tolist() rows = [rows[i][0] for i in range(len(rows))] return rows + def load_religous_data(self): + with open(self.file_path + "/../data/religous_data/train_religous.json", "r") as f: + data = json.load(f) + return data + def get_new_path(self): path = self.file_path + "/../data/generated" @@ -90,29 +123,25 @@ class Pipline: with open(json_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - - def get_a_data(self): - with self.lock: - if self.data_idx < len(self.data): - data = self.data[self.data_idx] - data_idx = self.data_idx - else: - data = None - data_idx = None - self.data_idx += 1 - return data, data_idx - def exec_function(self, passage): + with self.lock: + selected_worker = self.num_handling_request.index(min(self.num_handling_request)) + self.num_handling_request[selected_worker] += 1 + + try: - config = self.configuration.run(passage) - generated_data = self.query_generator.run(passage, config) + config = 
self.configuration[selected_worker].run(passage) + generated_data = self.query_generator[selected_worker].run(passage, config) one_data = config.copy() one_data["document"] = passage one_data["query"] = generated_data["query"] except Exception as e: one_data = {"passage": passage, "error": traceback.format_exc()} + + with self.lock: + self.num_handling_request[selected_worker] -= 1 return one_data @@ -171,12 +200,20 @@ class Pipline: def run(self, save_path = None): - data = self.load_data() - chunk_data = self.pre_process(data) - num_data = 250000 num_part_data = 25000 - num_threads = 5 + num_threads = 10 + + # num_data = 10 + # num_part_data = 10 + # num_threads = 10 + + # data = self.load_blogs_data() + data = self.load_religous_data() + random.shuffle(data) + data = data[:num_data] + chunk_data = self.pre_process(data) + if save_path == None: save_path = self.get_new_path() diff --git a/src/query_generator.py b/src/query_generator.py index 1a50080..ef17c03 100644 --- a/src/query_generator.py +++ b/src/query_generator.py @@ -21,10 +21,10 @@ OpenAIResponder = import_lib(os.path.dirname(__file__) , "openai_responder", "Op class QueryGenerator: - def __init__(self): - client = OpenAI(base_url=os.environ["OPENAI_BASE_URL"] ,api_key=os.environ["OPENAI_API_KEY"]) + def __init__(self, worker_config): + client = OpenAI(base_url=worker_config["OPENAI_BASE_URL"] ,api_key=worker_config["OPENAI_API_KEY"]) - self.openai_responder = OpenAIResponder(client=client, model=os.environ["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0) + self.openai_responder = OpenAIResponder(client=client, model=worker_config["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0) def get_prompt(self, passage, character, corpus_language, queries_language, difficulty, length, language, question_type): example = {