Query-Doc-Generator/src/parallel_requester.py
2025-11-30 14:35:25 +00:00

54 lines
1.5 KiB
Python

import threading
from typing import Callable, Any
from tqdm import tqdm
class ParallelRequester:
    """Apply a callable to every item of a sequence using worker threads.

    Workers claim items one at a time from a shared cursor protected by a
    lock, store each result under the item's original index, and ``run``
    returns the results in input order.
    """

    def __init__(self):
        # Guards ``data_idx`` so two workers never claim the same item.
        self.lock = threading.Lock()

    def get_a_data(self):
        """Claim the next unprocessed item.

        Returns:
            ``(item, index)`` for the claimed item, or ``(None, None)`` once
            every item has been handed out. Exhaustion must be detected via
            the *index*, because an item may itself be ``None``.
        """
        with self.lock:
            if self.data_idx >= len(self.data):
                return None, None
            data_idx = self.data_idx
            self.data_idx += 1
            return self.data[data_idx], data_idx

    def thread_function(self, exec_function: Callable[[Any], Any]):
        """Worker loop: process items until none are left.

        Args:
            exec_function: Called with each claimed item; its return value is
                stored in ``self.all_res`` under the item's index.

        Any exception raised by ``exec_function`` propagates to the caller.
        """
        while True:
            data, data_idx = self.get_a_data()
            # Test the index rather than the item: ``data == None`` would
            # stop the worker early on a legitimate ``None`` item and can
            # misbehave for objects that override ``__eq__``.
            if data_idx is None:
                return
            self.all_res[data_idx] = exec_function(data)
            self.pbar.update(1)

    def run(self, data, exec_function, num_threads):
        """Process ``data`` with ``exec_function`` on ``num_threads`` threads.

        Args:
            data: Sized, indexable collection of items to process.
            exec_function: Callable applied to each item.
            num_threads: Maximum number of worker threads to start.

        Returns:
            A list of results, ordered like ``data``.

        Raises:
            Exception: The first exception raised by ``exec_function`` in any
                worker, re-raised after all workers have finished. In that
                case ``self.all_res`` still holds the partial results.
        """
        self.data_idx = 0
        self.all_res = {}
        self.data = data
        self.pbar = tqdm(total=len(data), desc="Processing", unit="item")

        errors = []

        def worker():
            # Record failures instead of letting the thread die silently,
            # which would otherwise surface later as a missing result.
            try:
                self.thread_function(exec_function)
            except Exception as exc:
                errors.append(exc)

        try:
            # No point in starting more threads than there are items.
            thread_count = min(num_threads, len(data))
            allthreads = [
                threading.Thread(target=worker) for _ in range(thread_count)
            ]
            for thread in allthreads:
                thread.start()
            for thread in allthreads:
                thread.join()
        finally:
            # Always release the progress bar, even if a thread failed to start.
            self.pbar.close()

        if errors:
            raise errors[0]

        all_res = [self.all_res[i] for i in range(len(self.all_res))]
        del self.all_res
        del self.data
        return all_res