add MAX_NUM_THREAD_1
This commit is contained in:
parent
10e01837e6
commit
1296c151ee
@ -27,18 +27,18 @@ class Pipline:
|
||||
self.file_path = os.path.dirname(__file__)
|
||||
load_dotenv()
|
||||
|
||||
worker_configs = self.load_worker_configs()
|
||||
self.worker_configs = self.load_worker_configs()
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.num_handling_request = []
|
||||
self.configuration = []
|
||||
self.query_generator = []
|
||||
for i in range(len(worker_configs)):
|
||||
for i in range(len(self.worker_configs)):
|
||||
configuration = Configuration()
|
||||
configuration.init_persona(worker_configs[i])
|
||||
configuration.init_persona(self.worker_configs[i])
|
||||
self.configuration += [configuration]
|
||||
self.query_generator = [QueryGenerator(worker_configs[i])]
|
||||
self.num_handling_request = [0]
|
||||
self.query_generator += [QueryGenerator(self.worker_configs[i])]
|
||||
self.num_handling_request += [0]
|
||||
|
||||
|
||||
def load_worker_configs(self):
|
||||
@ -49,6 +49,7 @@ class Pipline:
|
||||
conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)]
|
||||
conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)]
|
||||
conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)]
|
||||
conf["MAX_NUM_THREAD"] = int(os.environ["MAX_NUM_THREAD_" + str(i)])
|
||||
worker_configs += [conf]
|
||||
except:
|
||||
continue
|
||||
@ -127,8 +128,15 @@ class Pipline:
|
||||
|
||||
def exec_function(self, passage):
|
||||
with self.lock:
|
||||
selected_worker = self.num_handling_request.index(min(self.num_handling_request))
|
||||
self.num_handling_request[selected_worker] += 1
|
||||
for i in range(len(self.worker_configs)):
|
||||
if self.num_handling_request[i] < self.worker_configs[i]["MAX_NUM_THREAD"]:
|
||||
self.num_handling_request[i] += 1
|
||||
selected_worker = i
|
||||
break
|
||||
else:
|
||||
selected_worker = self.num_handling_request.index(min(self.num_handling_request))
|
||||
self.num_handling_request[selected_worker] += 1
|
||||
|
||||
|
||||
|
||||
try:
|
||||
@ -138,7 +146,7 @@ class Pipline:
|
||||
one_data["document"] = passage
|
||||
one_data["query"] = generated_data["query"]
|
||||
except Exception as e:
|
||||
one_data = {"passage": passage, "error": traceback.format_exc()}
|
||||
one_data = {"passage": passage, "error": traceback.format_exc(), "selected_worker": selected_worker}
|
||||
|
||||
with self.lock:
|
||||
self.num_handling_request[selected_worker] -= 1
|
||||
@ -202,14 +210,14 @@ class Pipline:
|
||||
def run(self, save_path = None):
|
||||
num_data = 250000
|
||||
num_part_data = 25000
|
||||
num_threads = 10
|
||||
|
||||
# num_data = 10
|
||||
# num_part_data = 10
|
||||
# num_threads = 10
|
||||
|
||||
# data = self.load_blogs_data()
|
||||
num_threads = sum([self.worker_configs[i]["MAX_NUM_THREAD"] for i in range(len(self.worker_configs))])
|
||||
data = self.load_religous_data()
|
||||
|
||||
# num_data = 25
|
||||
# num_part_data = 25
|
||||
# num_threads = 10
|
||||
# data = self.load_blogs_data()
|
||||
|
||||
random.shuffle(data)
|
||||
data = data[:num_data]
|
||||
chunk_data = self.pre_process(data)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user