add multiworker

parent efb1dd6771
commit 71ec024a30
@@ -134,11 +134,11 @@ Ensure to generate only the JSON output with content in English.
         return config_prompt


-    def init_persona(self):
+    def init_persona(self, worker_config):
         self.index = faiss.read_index(self.file_path + "/../data/faiss.index")
         self.all_persona = self.load_all_persona()
-        client = OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"])
-        self.openai_responder = OpenAIResponder(client=client, model=os.environ["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)
+        client = OpenAI(base_url=worker_config["OPENAI_BASE_URL"], api_key=worker_config["OPENAI_API_KEY"])
+        self.openai_responder = OpenAIResponder(client=client, model=worker_config["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)


     def get_persona(self, passage):
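Note: with this hunk, Configuration.init_persona builds its OpenAI client from a per-worker dict instead of the process-wide environment. A minimal sketch of the call, assuming Configuration is loaded as elsewhere in this repo; the dict keys are the ones the diff reads, and every value is a placeholder:

# Hypothetical worker_config; keys match what init_persona reads, values are placeholders.
worker_config = {
    "OPENAI_BASE_URL": "http://localhost:8000/v1",
    "OPENAI_API_KEY": "placeholder-key",
    "OPENAI_MODEL": "placeholder-model",
}

configuration = Configuration()
configuration.init_persona(worker_config)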
@@ -35,7 +35,7 @@ class OpenAIResponder:

     def get_body_to_request(self, messages, temperature):

-        body = {"model": self.model, "messages": messages, "max_tokens": 8000}
+        body = {"model": self.model, "messages": messages, "max_tokens": 1000}
         if temperature != None:
             body["temperature"] = temperature
         return body
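Note: max_tokens drops from 8000 to 1000 for every request body this method builds. A sketch of the resulting payload for a one-message call, assuming responder is an already-constructed OpenAIResponder; the message content is a placeholder:

body = responder.get_body_to_request(
    messages=[{"role": "user", "content": "placeholder prompt"}],
    temperature=0.7,
)
# body == {"model": responder.model, "messages": [...],
#          "max_tokens": 1000, "temperature": 0.7}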
@@ -6,6 +6,9 @@ import random
 import tqdm
 import pandas as pd
 import traceback
+import threading
+from dotenv import load_dotenv
+

 def import_lib(path, file_name, package_name):
     file_path = path + "/" + file_name + ".py"
@@ -22,17 +25,47 @@ ParallelRequester = import_lib(os.path.dirname(__file__), "parallel_requester", "ParallelRequester")
 class Pipline:
     def __init__(self):
         self.file_path = os.path.dirname(__file__)
-        self.configuration = Configuration()
-        self.configuration.init_persona()
-        self.query_generator = QueryGenerator()
+        load_dotenv()
+
+        worker_configs = self.load_worker_configs()
+
+        self.lock = threading.Lock()
+        self.num_handling_request = []
+        self.configuration = []
+        self.query_generator = []
+        for i in range(len(worker_configs)):
+            configuration = Configuration()
+            configuration.init_persona(worker_configs[i])
+            self.configuration += [configuration]
+            self.query_generator += [QueryGenerator(worker_configs[i])]
+            self.num_handling_request += [0]
+
+
+    def load_worker_configs(self):
+        worker_configs = []
+        for i in range(100):
+            try:
+                conf = {}
+                conf["OPENAI_BASE_URL"] = os.environ["OPENAI_BASE_URL_" + str(i)]
+                conf["OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY_" + str(i)]
+                conf["OPENAI_MODEL"] = os.environ["OPENAI_MODEL_" + str(i)]
+                worker_configs += [conf]
+            except KeyError:
+                continue
+        return worker_configs


-    def load_data(self):
+    def load_blogs_data(self):
         df = pd.read_csv(self.file_path + "/../data/persian_blog/blogs.csv")
         rows = df.values.tolist()
         rows = [rows[i][0] for i in range(len(rows))]
         return rows

+    def load_religous_data(self):
+        with open(self.file_path + "/../data/religous_data/train_religous.json", "r") as f:
+            data = json.load(f)
+        return data


     def get_new_path(self):
         path = self.file_path + "/../data/generated"
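Note: load_worker_configs probes OPENAI_*_0 through OPENAI_*_99 and keeps every index for which all three variables exist, so workers are declared purely through the environment. A hypothetical .env file (read by the load_dotenv call above) declaring two workers; all values are placeholders:

OPENAI_BASE_URL_0=http://host-a:8000/v1
OPENAI_API_KEY_0=placeholder-key-0
OPENAI_MODEL_0=placeholder-model-a
OPENAI_BASE_URL_1=http://host-b:8000/v1
OPENAI_API_KEY_1=placeholder-key-1
OPENAI_MODEL_1=placeholder-model-b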
@@ -90,29 +123,25 @@ class Pipline:

         with open(json_path, "w", encoding="utf-8") as f:
             json.dump(data, f, ensure_ascii=False, indent=2)


-    def get_a_data(self):
-        with self.lock:
-            if self.data_idx < len(self.data):
-                data = self.data[self.data_idx]
-                data_idx = self.data_idx
-            else:
-                data = None
-                data_idx = None
-            self.data_idx += 1
-            return data, data_idx
-
-
     def exec_function(self, passage):
+        with self.lock:
+            selected_worker = self.num_handling_request.index(min(self.num_handling_request))
+            self.num_handling_request[selected_worker] += 1
+
         try:
-            config = self.configuration.run(passage)
-            generated_data = self.query_generator.run(passage, config)
+            config = self.configuration[selected_worker].run(passage)
+            generated_data = self.query_generator[selected_worker].run(passage, config)
             one_data = config.copy()
             one_data["document"] = passage
             one_data["query"] = generated_data["query"]
         except Exception as e:
             one_data = {"passage": passage, "error": traceback.format_exc()}

+        with self.lock:
+            self.num_handling_request[selected_worker] -= 1
         return one_data
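Note: the dispatch in exec_function is a least-connections policy: index(min(...)) returns the first worker with the fewest in-flight requests, and the counter is adjusted under the lock on both entry and exit. A self-contained sketch of the same pattern, with example counter values:

import threading

lock = threading.Lock()
num_handling_request = [2, 0, 1]  # in-flight requests per worker (example values)

with lock:
    # Picks the first index holding the minimum: worker 1 here.
    selected_worker = num_handling_request.index(min(num_handling_request))
    num_handling_request[selected_worker] += 1

# ... issue the request to the selected worker ...

with lock:
    num_handling_request[selected_worker] -= 1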
@@ -171,12 +200,20 @@ class Pipline:

     def run(self, save_path = None):
-        data = self.load_data()
-        chunk_data = self.pre_process(data)
-
         num_data = 250000
         num_part_data = 25000
-        num_threads = 5
+        num_threads = 10
+
+        # num_data = 10
+        # num_part_data = 10
+        # num_threads = 10
+
+        # data = self.load_blogs_data()
+        data = self.load_religous_data()
+        random.shuffle(data)
+        data = data[:num_data]
+        chunk_data = self.pre_process(data)


         if save_path == None:
             save_path = self.get_new_path()
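Note: under the new defaults, run samples num_data documents from the religous corpus and ten threads share the calls to exec_function; assuming output is flushed every num_part_data items, the generated data lands in ten parts. The arithmetic, for reference:

num_parts = 250000 // 25000  # 10 output parts under the defaults above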
@@ -21,10 +21,10 @@ OpenAIResponder = import_lib(os.path.dirname(__file__), "openai_responder", "OpenAIResponder")


 class QueryGenerator:
-    def __init__(self):
-        client = OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"])
+    def __init__(self, worker_config):
+        client = OpenAI(base_url=worker_config["OPENAI_BASE_URL"], api_key=worker_config["OPENAI_API_KEY"])

-        self.openai_responder = OpenAIResponder(client=client, model=os.environ["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)
+        self.openai_responder = OpenAIResponder(client=client, model=worker_config["OPENAI_MODEL"], price_per_1m_input_tokens=0, price_per_1m_output_tokens=0)

     def get_prompt(self, passage, character, corpus_language, queries_language, difficulty, length, language, question_type):
         example = {
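Note: QueryGenerator now mirrors Configuration: each instance is pinned to one worker's endpoint, key, and model. A hypothetical construction, reusing the worker_config shape sketched earlier (QueryGenerator is loaded via import_lib as shown in the hunk header):

query_generator = QueryGenerator(worker_config)
# Later, per passage, as called from Pipline.exec_function:
# generated_data = query_generator.run(passage, config)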