Add per-worker MAX_NUM_THREAD_<i> setting (e.g. MAX_NUM_THREAD_1) used for worker selection and total thread count

This commit is contained in:
hediehloo 2025-12-10 11:05:58 +00:00
parent 10e01837e6
commit 1296c151ee

View File

@ -27,18 +27,18 @@ class Pipline:
self.file_path = os.path.dirname(__file__) self.file_path = os.path.dirname(__file__)
load_dotenv() load_dotenv()
worker_configs = self.load_worker_configs() self.worker_configs = self.load_worker_configs()
self.lock = threading.Lock() self.lock = threading.Lock()
self.num_handling_request = [] self.num_handling_request = []
self.configuration = [] self.configuration = []
self.query_generator = [] self.query_generator = []
for i in range(len(worker_configs)): for i in range(len(self.worker_configs)):
configuration = Configuration() configuration = Configuration()
configuration.init_persona(worker_configs[i]) configuration.init_persona(self.worker_configs[i])
self.configuration += [configuration] self.configuration += [configuration]
self.query_generator = [QueryGenerator(worker_configs[i])] self.query_generator += [QueryGenerator(self.worker_configs[i])]
self.num_handling_request = [0] self.num_handling_request += [0]
def load_worker_configs(self): def load_worker_configs(self):
@ -49,6 +49,7 @@ class Pipline:
conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)] conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)]
conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)] conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)]
conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)] conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)]
conf["MAX_NUM_THREAD"] = int(os.environ["MAX_NUM_THREAD_" + str(i)])
worker_configs += [conf] worker_configs += [conf]
except: except:
continue continue
@ -127,8 +128,15 @@ class Pipline:
def exec_function(self, passage): def exec_function(self, passage):
with self.lock: with self.lock:
selected_worker = self.num_handling_request.index(min(self.num_handling_request)) for i in range(len(self.worker_configs)):
self.num_handling_request[selected_worker] += 1 if self.num_handling_request[i] < self.worker_configs[i]["MAX_NUM_THREAD"]:
self.num_handling_request[i] += 1
selected_worker = i
break
else:
selected_worker = self.num_handling_request.index(min(self.num_handling_request))
self.num_handling_request[selected_worker] += 1
try: try:
@ -138,7 +146,7 @@ class Pipline:
one_data["document"] = passage one_data["document"] = passage
one_data["query"] = generated_data["query"] one_data["query"] = generated_data["query"]
except Exception as e: except Exception as e:
one_data = {"passage": passage, "error": traceback.format_exc()} one_data = {"passage": passage, "error": traceback.format_exc(), "selected_worker": selected_worker}
with self.lock: with self.lock:
self.num_handling_request[selected_worker] -= 1 self.num_handling_request[selected_worker] -= 1
@ -202,14 +210,14 @@ class Pipline:
def run(self, save_path = None): def run(self, save_path = None):
num_data = 250000 num_data = 250000
num_part_data = 25000 num_part_data = 25000
num_threads = 10 num_threads = sum([self.worker_configs[i]["MAX_NUM_THREAD"] for i in range(len(self.worker_configs))])
# num_data = 10
# num_part_data = 10
# num_threads = 10
# data = self.load_blogs_data()
data = self.load_religous_data() data = self.load_religous_data()
# num_data = 25
# num_part_data = 25
# num_threads = 10
# data = self.load_blogs_data()
random.shuffle(data) random.shuffle(data)
data = data[:num_data] data = data[:num_data]
chunk_data = self.pre_process(data) chunk_data = self.pre_process(data)