add bge and part save
parent e6cf02fcee
commit fe1211a907
@@ -52,6 +52,25 @@ class Configuration:
         response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)

         return response.json()


+    def get_bge_embedding(self, sentece, prompt_name):
+        embedding_url = "https://bge.chatllm.aiengines.ir/v1/embeddings"
+        headers = {"accept": "application/json"}
+        headers["Content-Type"] = "application/json"
+        headers["Authorization"] = f"Bearer {os.environ['EMBEDDING_PASS']}"
+
+        data = {}
+        data["model"] = "BAAI/bge-m3"
+        data["input"] = sentece
+        data["normalize"] = True
+
+        response = self.session.post(embedding_url, headers=headers, data=json.dumps(data), timeout=600)
+        res = response.json()
+
+        final_res = [res["data"][i]["embedding"] for i in range(len(sentece))]
+        return final_res
+
+
     def embedding_persona(self):
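Note (illustration only, not part of the commit): the new get_bge_embedding takes a list of strings in its sentece parameter and, judging from the list comprehension, assumes the endpoint returns an OpenAI-style embeddings payload with one "data" entry per input; prompt_name is accepted but never placed in the request body. A minimal sketch of that unpacking, with a fabricated response:

    # Sketch only: mirrors the unpacking in get_bge_embedding, invented values.
    sample_response = {
        "data": [
            {"embedding": [0.01, -0.02, 0.03]},   # vector for input 0
            {"embedding": [0.04, 0.05, -0.06]},   # vector for input 1
        ]
    }
    inputs = ["first sentence", "second sentence"]
    final_res = [sample_response["data"][i]["embedding"] for i in range(len(inputs))]
    assert len(final_res) == len(inputs)          # one vector per input, in order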
@@ -64,7 +83,7 @@ class Configuration:
         for i in tqdm.trange(0, len(all_persona), batch_size):
             start_idx = i
             stop_idx = min(len(all_persona), start_idx+batch_size)
-            all_embeddings += self.get_embedding(all_persona[start_idx:stop_idx], prompt_name)
+            all_embeddings += self.get_bge_embedding(all_persona[start_idx:stop_idx], prompt_name)

         xb = numpy.array(all_embeddings).astype('float32')
         index = faiss.IndexFlatL2(len(all_embeddings[0]))
@@ -123,7 +142,7 @@ Ensure to generate only the JSON output with content in English.


     def get_persona(self, passage):
-        query_embedding = self.get_embedding(passage, "query")
+        query_embedding = self.get_bge_embedding([passage], "query")
         query = numpy.array(query_embedding, dtype='float32')


@@ -32,15 +32,39 @@ class Pipline:
         rows = df.values.tolist()
         rows = [rows[i][0] for i in range(len(rows))]
         return rows


-    def save_dataset(self, data):
+    def get_new_path(self):
         path = self.file_path + "/../data/generated"
         if not os.path.exists(path):
             os.makedirs(path)

-        files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]
+        folders = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]

-        pattern = r"^v(\d+)_dataset\.json$"
+        pattern = r"^v(\d+)$"
+
+        all_numbers = []
+
+        for f in folders:
+            match = re.match(pattern, f)
+            if match:
+                num = int(match.group(1))
+                all_numbers.append(num)
+
+        if all_numbers:
+            number = max(all_numbers) + 1
+        else:
+            number = 1
+
+        path = os.path.join(path, "v" + str(number))
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def get_json_path(self, save_path):
+        files = [f for f in os.listdir(save_path) if os.path.isfile(os.path.join(save_path, f))]
+
+        pattern = r"^part_(\d+)_dataset\.json$"

         all_numbers = []
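Aside (not from the commit): get_new_path reuses the old version-numbering idea, but on folders instead of files. Its numbering can be checked in isolation; the folder names below are invented, and the logic matches the loop in the hunk above (Python 3.8+ for the assignment expression):

    import re

    folders = ["v1", "v2", "v10", "old_stuff"]   # hypothetical contents of data/generated
    pattern = r"^v(\d+)$"
    all_numbers = [int(m.group(1)) for f in folders if (m := re.match(pattern, f))]
    number = max(all_numbers) + 1 if all_numbers else 1
    print("v" + str(number))                     # -> v11, the next run's folder name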
@@ -54,8 +78,17 @@ class Pipline:
             number = max(all_numbers) + 1
         else:
             number = 1

-        with open(path + "/v" + str(number) + "_dataset.json", "w", encoding="utf-8") as f:
+        json_path = os.path.join(save_path, "part_" + str(number) + "_dataset.json")
+        return json_path
+
+
+    def save_dataset(self, data, save_path):
+
+        json_path = self.get_json_path(save_path)
+
+        with open(json_path, "w", encoding="utf-8") as f:
             json.dump(data, f, ensure_ascii=False, indent=2)
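Aside (illustrative paths only): with these two helpers, each run writes into a fresh data/generated/v{N} folder returned by get_new_path, and every save_dataset call drops the next part_{M}_dataset.json inside it. A tiny sketch of where two saved parts of a hypothetical third run would land:

    import os

    save_path = "data/generated/v3"                          # as returned by get_new_path
    parts = ["part_1_dataset.json", "part_2_dataset.json"]   # successive save_dataset calls
    print([os.path.join(save_path, p) for p in parts])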
@@ -88,7 +121,10 @@ class Pipline:
         for i in range(start_idx, len(sentences)):
             if len(one_passage) + len(sentences[i]) > selected_lenth and len(one_passage) > 0:
                 return one_passage, i
-            one_passage += sentences[i]
+            if one_passage == "":
+                one_passage += sentences[i]
+            else:
+                one_passage += "." + sentences[i]
         return one_passage, len(sentences)
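Aside (not from the commit): the new branch re-inserts the "." that sentence splitting removed, so rebuilt passages read as separate sentences instead of running together. A standalone sketch with invented sentences:

    sentences = ["First sentence", " second sentence", " third sentence"]
    one_passage = ""
    for s in sentences:
        if one_passage == "":
            one_passage += s
        else:
            one_passage += "." + s
    print(one_passage)   # -> "First sentence. second sentence. third sentence"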
@@ -127,17 +163,28 @@ class Pipline:
         return chunk_data


-    def run(self):
+    def run_one_part(self, chunk_data, save_path, num_threads):
+        parallel_requester = ParallelRequester()
+        dataset = parallel_requester.run(chunk_data, self.exec_function, num_threads)
+
+        self.save_dataset(dataset, save_path)
+
+
+    def run(self, save_path = None):
         data = self.load_data()
         chunk_data = self.pre_process(data)

-        num_data = 25000
+        num_data = 250000
+        num_part_data = 25000
         num_threads = 5

-        parallel_requester = ParallelRequester()
-        dataset = parallel_requester.run(chunk_data[0:num_data], self.exec_function, num_threads)
+        if save_path == None:
+            save_path = self.get_new_path()

-        self.save_dataset(dataset)
+        for i in range(0, num_data, num_part_data):
+            start_idx = i
+            stop_idx = min(i+num_part_data, num_data)
+            self.run_one_part(chunk_data[start_idx:stop_idx], save_path, num_threads)


 def main():
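Aside (not from the commit): with num_data raised to 250000 and num_part_data at 25000, run now issues ten part runs, each saved to its own part file, instead of one monolithic request over 25000 chunks. The slicing can be checked on its own:

    num_data = 250000
    num_part_data = 25000
    parts = [(i, min(i + num_part_data, num_data)) for i in range(0, num_data, num_part_data)]
    print(len(parts))              # -> 10 parts
    print(parts[0], parts[-1])     # -> (0, 25000) (225000, 250000)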