add MAX_NUM_THREAD_1

This commit is contained in:
hediehloo 2025-12-10 11:05:58 +00:00
parent 10e01837e6
commit 1296c151ee

View File

@ -27,18 +27,18 @@ class Pipline:
self.file_path = os.path.dirname(__file__)
load_dotenv()
worker_configs = self.load_worker_configs()
self.worker_configs = self.load_worker_configs()
self.lock = threading.Lock()
self.num_handling_request = []
self.configuration = []
self.query_generator = []
for i in range(len(worker_configs)):
for i in range(len(self.worker_configs)):
configuration = Configuration()
configuration.init_persona(worker_configs[i])
configuration.init_persona(self.worker_configs[i])
self.configuration += [configuration]
self.query_generator = [QueryGenerator(worker_configs[i])]
self.num_handling_request = [0]
self.query_generator += [QueryGenerator(self.worker_configs[i])]
self.num_handling_request += [0]
def load_worker_configs(self):
@ -49,6 +49,7 @@ class Pipline:
conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)]
conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)]
conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)]
conf["MAX_NUM_THREAD"] = int(os.environ["MAX_NUM_THREAD_" + str(i)])
worker_configs += [conf]
except:
continue
@ -127,10 +128,17 @@ class Pipline:
def exec_function(self, passage):
with self.lock:
for i in range(len(self.worker_configs)):
if self.num_handling_request[i] < self.worker_configs[i]["MAX_NUM_THREAD"]:
self.num_handling_request[i] += 1
selected_worker = i
break
else:
selected_worker = self.num_handling_request.index(min(self.num_handling_request))
self.num_handling_request[selected_worker] += 1
try:
config = self.configuration[selected_worker].run(passage)
generated_data = self.query_generator[selected_worker].run(passage, config)
@ -138,7 +146,7 @@ class Pipline:
one_data["document"] = passage
one_data["query"] = generated_data["query"]
except Exception as e:
one_data = {"passage": passage, "error": traceback.format_exc()}
one_data = {"passage": passage, "error": traceback.format_exc(), "selected_worker": selected_worker}
with self.lock:
self.num_handling_request[selected_worker] -= 1
@ -202,14 +210,14 @@ class Pipline:
def run(self, save_path = None):
num_data = 250000
num_part_data = 25000
num_threads = 10
# num_data = 10
# num_part_data = 10
# num_threads = 10
# data = self.load_blogs_data()
num_threads = sum([self.worker_configs[i]["MAX_NUM_THREAD"] for i in range(len(self.worker_configs))])
data = self.load_religous_data()
# num_data = 25
# num_part_data = 25
# num_threads = 10
# data = self.load_blogs_data()
random.shuffle(data)
data = data[:num_data]
chunk_data = self.pre_process(data)