54 lines
1.5 KiB
Python
54 lines
1.5 KiB
Python
import threading
|
|
from typing import Callable, Any
|
|
from tqdm import tqdm
|
|
|
|
class ParallelRequester:
    """Apply a function to every item of a sequence using worker threads.

    Workers pull items one at a time from a shared cursor guarded by
    ``self.lock``, so each item is handed to exactly one worker. ``run``
    returns the results in input order and shows a tqdm progress bar while
    working. Per-run state is stored on the instance, so do not call ``run``
    concurrently on the same instance.
    """

    def __init__(self):
        # Guards the shared cursor, the result dict, the error slot and the
        # progress bar, all of which are touched by several worker threads.
        self.lock = threading.Lock()
        # Per-run state, reset by run(). Initialised here so get_a_data() on
        # an idle instance reports "nothing to do" instead of raising
        # AttributeError.
        self.data = []
        self.data_idx = 0
        self.all_res = {}
        self.pbar = None
        # First exception raised by exec_function during the current run.
        self._error = None

    def get_a_data(self):
        """Claim the next unprocessed item.

        Returns:
            ``(item, index)`` for the next item, or ``(None, None)`` when all
            items have been handed out or a worker has already failed.
            Detect exhaustion by testing ``index is None``: ``item`` itself
            may legitimately be ``None``.
        """
        with self.lock:
            # Stop handing out work once any worker has failed (fail fast).
            if self._error is not None or self.data_idx >= len(self.data):
                return None, None
            data_idx = self.data_idx
            self.data_idx += 1
            return self.data[data_idx], data_idx

    def thread_function(self, exec_function: Callable[[Any], Any]) -> None:
        """Worker loop: process claimed items until none remain.

        Each result is stored in ``self.all_res`` under its input index. The
        first exception raised by ``exec_function`` is recorded in
        ``self._error`` (for ``run`` to re-raise) and ends this worker.
        """
        while True:
            data, data_idx = self.get_a_data()
            # Test the index, not the item: a None (or array-valued) item is
            # valid input and must neither stop the worker nor be compared
            # with ``==``.
            if data_idx is None:
                return
            try:
                res = exec_function(data)
            except BaseException as err:  # re-raised by run() in the caller's thread
                with self.lock:
                    if self._error is None:
                        self._error = err
                return
            with self.lock:
                self.all_res[data_idx] = res
                # Serialised with the other workers; we do not rely on
                # tqdm.update() being safe to call concurrently.
                self.pbar.update(1)

    def run(self, data, exec_function: Callable[[Any], Any], num_threads: int) -> list:
        """Apply ``exec_function`` to every item of ``data`` on ``num_threads`` threads.

        Args:
            data: Items to process; must support ``len()`` and integer indexing.
            exec_function: Callable invoked once per item.
            num_threads: Number of worker threads to start.

        Returns:
            A list whose i-th element is ``exec_function(data[i])``.

        Raises:
            ValueError: ``num_threads`` is less than 1 while ``data`` is non-empty.
            Exception: the first exception raised by ``exec_function``;
                remaining workers stop picking up new items once it occurs.
        """
        if num_threads < 1 and len(data) > 0:
            raise ValueError(f"num_threads must be at least 1, got {num_threads}")

        self.data = data
        self.data_idx = 0
        self.all_res = {}
        self._error = None
        self.pbar = tqdm(total=len(data), desc="Processing", unit="item")
        try:
            threads = [
                threading.Thread(target=self.thread_function, args=(exec_function,))
                for _ in range(num_threads)
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            error = self._error
            if error is not None:
                raise error
            # Every index is present here: no worker failed and the workers
            # drained the whole cursor.
            return [self.all_res[i] for i in range(len(data))]
        finally:
            self.pbar.close()
            # Drop references to the caller's data and the results so they
            # can be garbage-collected once run() returns.
            self.data = []
            self.all_res = {}
            self._error = None