Add multi-worker support: one OpenAI endpoint (configuration + query generator) per worker, selected by fewest in-flight requests

This commit is contained in:
hediehloo 2025-12-10 02:40:08 +00:00
parent efb1dd6771
commit 71ec024a30
4 changed files with 66 additions and 29 deletions

View File

@ -134,11 +134,11 @@ Ensure to generate only the JSON output with content in English.
return config_prompt
def init_persona(self, worker_config):
    """Initialise persona retrieval for one worker.

    Loads the FAISS persona index and the persona records, then binds an
    OpenAI-compatible client for this worker's endpoint.

    Args:
        worker_config: dict with "OPENAI_BASE_URL", "OPENAI_API_KEY" and
            "OPENAI_MODEL" entries for this worker.
    """
    # FAISS index of persona embeddings, shipped in ../data next to the code.
    self.index = faiss.read_index(self.file_path + "/../data/faiss.index")
    self.all_persona = self.load_all_persona()
    # Each worker talks to its own OpenAI-compatible endpoint; prices are
    # zeroed because cost tracking is not used here.
    client = OpenAI(base_url=worker_config["OPENAI_BASE_URL"], api_key=worker_config["OPENAI_API_KEY"])
    self.openai_responder = OpenAIResponder(client=client, model=worker_config["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)
def get_persona(self, passage):

View File

@ -35,7 +35,7 @@ class OpenAIResponder:
def get_body_to_request(self, messages, temperature):
    """Build the JSON payload for a chat-completions request.

    Args:
        messages: list of chat message dicts ({"role", "content"}).
        temperature: sampling temperature, or None to use the server default.

    Returns:
        dict: request body with model, messages and a max_tokens cap.
    """
    body = {"model": self.model, "messages": messages, "max_tokens": 1000}
    # `is not None` (not `!= None`): temperature == 0 must still be sent.
    if temperature is not None:
        body["temperature"] = temperature
    return body

View File

@ -6,6 +6,9 @@ import random
import tqdm
import pandas as pd
import traceback
import threading
from dotenv import load_dotenv
def import_lib(path, file_name, package_name):
file_path = path + "/" + file_name + ".py"
@ -22,17 +25,47 @@ ParallelRequester = import_lib(os.path.dirname(__file__) , "parallel_requester",
class Pipline:
def __init__(self):
    """Set up one Configuration/QueryGenerator pair per configured worker.

    Worker endpoints are read from OPENAI_*_<i> environment variables
    (see load_worker_configs); a .env file is loaded first via dotenv.
    """
    self.file_path = os.path.dirname(__file__)
    load_dotenv()
    worker_configs = self.load_worker_configs()
    # Guards the shared per-worker request counters and the data cursor.
    self.lock = threading.Lock()
    self.num_handling_request = []
    self.configuration = []
    self.query_generator = []
    for worker_config in worker_configs:
        configuration = Configuration()
        configuration.init_persona(worker_config)
        # BUG FIX: the original used plain `=` for query_generator and
        # num_handling_request inside this loop, overwriting them each
        # iteration so only the LAST worker survived while configuration
        # kept growing — exec_function then indexed out of range for any
        # selected_worker > 0. Append keeps all three lists in lockstep.
        self.configuration.append(configuration)
        self.query_generator.append(QueryGenerator(worker_config))
        self.num_handling_request.append(0)
def load_worker_configs(self):
    """Collect per-worker OpenAI settings from environment variables.

    Scans OPENAI_BASE_URL_<i>, OPENAI_API_KEY_<i> and OPENAI_MODEL_<i>
    for i in 0..99; an index missing any of the three is skipped.

    Returns:
        list[dict]: one config dict per fully specified worker index.
    """
    worker_configs = []
    for i in range(100):
        try:
            conf = {
                "OPENAI_BASE_URL": os.environ["OPENAI_BASE_URL_" + str(i)],
                "OPENAI_API_KEY": os.environ["OPENAI_API_KEY_" + str(i)],
                "OPENAI_MODEL": os.environ["OPENAI_MODEL_" + str(i)],
            }
        # Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; only a missing variable is expected.
        except KeyError:
            continue
        worker_configs.append(conf)
    return worker_configs
def load_data(self):
def load_blogs_data(self):
df = pd.read_csv(self.file_path + "/../data/persian_blog/blogs.csv")
rows = df.values.tolist()
rows = [rows[i][0] for i in range(len(rows))]
return rows
def load_religous_data(self):
with open(self.file_path + "/../data/religous_data/train_religous.json", "r") as f:
data = json.load(f)
return data
def get_new_path(self):
path = self.file_path + "/../data/generated"
@ -90,29 +123,25 @@ class Pipline:
with open(json_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def get_a_data(self):
    """Atomically hand out the next pending item and its index.

    Thread-safe: the cursor is read and advanced under self.lock, so each
    caller receives a distinct index. Once the data is exhausted, every
    subsequent call returns (None, None).

    Returns:
        tuple: (item, index), or (None, None) when no data remains.
    """
    with self.lock:
        cursor = self.data_idx
        # Advance unconditionally — matches the original, which also kept
        # incrementing past the end.
        self.data_idx = cursor + 1
        if cursor >= len(self.data):
            return None, None
        return self.data[cursor], cursor
def exec_function(self, passage):
    """Generate one training example for *passage* on the least-loaded worker.

    Picks the worker with the fewest in-flight requests (under the lock),
    runs its Configuration and QueryGenerator, and always releases the
    worker's slot when done.

    Args:
        passage: the source document text.

    Returns:
        dict: the generated example (config fields + "document" + "query"),
        or {"passage", "error"} with a formatted traceback on failure.
    """
    with self.lock:
        # Least-loaded worker gets the request.
        selected_worker = self.num_handling_request.index(min(self.num_handling_request))
        self.num_handling_request[selected_worker] += 1
    try:
        try:
            config = self.configuration[selected_worker].run(passage)
            generated_data = self.query_generator[selected_worker].run(passage, config)
            one_data = config.copy()
            one_data["document"] = passage
            one_data["query"] = generated_data["query"]
        except Exception:
            # Best-effort: record the failure instead of killing the thread.
            one_data = {"passage": passage, "error": traceback.format_exc()}
    finally:
        # finally (not plain fall-through as before): the slot must be
        # released even on a non-Exception unwind such as KeyboardInterrupt,
        # or the worker would look permanently busy.
        with self.lock:
            self.num_handling_request[selected_worker] -= 1
    return one_data
@ -171,12 +200,20 @@ class Pipline:
def run(self, save_path = None):
data = self.load_data()
chunk_data = self.pre_process(data)
num_data = 250000
num_part_data = 25000
num_threads = 5
num_threads = 10
# num_data = 10
# num_part_data = 10
# num_threads = 10
# data = self.load_blogs_data()
data = self.load_religous_data()
random.shuffle(data)
data = data[:num_data]
chunk_data = self.pre_process(data)
if save_path == None:
save_path = self.get_new_path()

View File

@ -21,10 +21,10 @@ OpenAIResponder = import_lib(os.path.dirname(__file__) , "openai_responder", "Op
class QueryGenerator:
def __init__(self):
client = OpenAI(base_url=os.environ["OPENAI_BASE_URL"] ,api_key=os.environ["OPENAI_API_KEY"])
def __init__(self, worker_config):
client = OpenAI(base_url=worker_config["OPENAI_BASE_URL"] ,api_key=worker_config["OPENAI_API_KEY"])
self.openai_responder = OpenAIResponder(client=client, model=os.environ["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)
self.openai_responder = OpenAIResponder(client=client, model=worker_config["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)
def get_prompt(self, passage, character, corpus_language, queries_language, difficulty, length, language, question_type):
example = {