From 1296c151ee77917ef14edae6898bcaaaabab7f1e Mon Sep 17 00:00:00 2001 From: hediehloo Date: Wed, 10 Dec 2025 11:05:58 +0000 Subject: [PATCH] add MAX_NUM_THREAD_1 --- src/pipline.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/pipline.py b/src/pipline.py index 16224cd..16ad78f 100644 --- a/src/pipline.py +++ b/src/pipline.py @@ -27,18 +27,18 @@ class Pipline: self.file_path = os.path.dirname(__file__) load_dotenv() - worker_configs = self.load_worker_configs() + self.worker_configs = self.load_worker_configs() self.lock = threading.Lock() self.num_handling_request = [] self.configuration = [] self.query_generator = [] - for i in range(len(worker_configs)): + for i in range(len(self.worker_configs)): configuration = Configuration() - configuration.init_persona(worker_configs[i]) + configuration.init_persona(self.worker_configs[i]) self.configuration += [configuration] - self.query_generator = [QueryGenerator(worker_configs[i])] - self.num_handling_request = [0] + self.query_generator += [QueryGenerator(self.worker_configs[i])] + self.num_handling_request += [0] def load_worker_configs(self): @@ -49,6 +49,7 @@ class Pipline: conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)] conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)] conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)] + conf["MAX_NUM_THREAD"] = int(os.environ["MAX_NUM_THREAD_" + str(i)]) worker_configs += [conf] except: continue @@ -127,8 +128,15 @@ class Pipline: def exec_function(self, passage): with self.lock: - selected_worker = self.num_handling_request.index(min(self.num_handling_request)) - self.num_handling_request[selected_worker] += 1 + for i in range(len(self.worker_configs)): + if self.num_handling_request[i] < self.worker_configs[i]["MAX_NUM_THREAD"]: + self.num_handling_request[i] += 1 + selected_worker = i + break + else: + selected_worker = self.num_handling_request.index(min(self.num_handling_request)) + self.num_handling_request[selected_worker] += 1 + try: @@ -138,7 +146,7 @@ class Pipline: one_data["document"] = passage one_data["query"] = generated_data["query"] except Exception as e: - one_data = {"passage": passage, "error": traceback.format_exc()} + one_data = {"passage": passage, "error": traceback.format_exc(), "selected_worker": selected_worker} with self.lock: self.num_handling_request[selected_worker] -= 1 @@ -202,14 +210,14 @@ class Pipline: def run(self, save_path = None): num_data = 250000 num_part_data = 25000 - num_threads = 10 - - # num_data = 10 - # num_part_data = 10 - # num_threads = 10 - - # data = self.load_blogs_data() + num_threads = sum([self.worker_configs[i]["MAX_NUM_THREAD"] for i in range(len(self.worker_configs))]) data = self.load_religous_data() + + # num_data = 25 + # num_part_data = 25 + # num_threads = 10 + # data = self.load_blogs_data() + random.shuffle(data) data = data[:num_data] chunk_data = self.pre_process(data)