import json
import os
import importlib.util  # importlib.util is required for the dynamic loading in import_lib below
import re
import random
import traceback

import tqdm
import pandas as pd


def import_lib(path, file_name, package_name):
    """Dynamically load attribute `package_name` from the file `<path>/<file_name>.py`."""
    file_path = path + "/" + file_name + ".py"
    spec = importlib.util.spec_from_file_location(file_name, file_path)
    imported_file = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(imported_file)
    return getattr(imported_file, package_name)


Configuration = import_lib(os.path.dirname(__file__), "configuration", "Configuration")
QueryGenerator = import_lib(os.path.dirname(__file__), "query_generator", "QueryGenerator")
ParallelRequester = import_lib(os.path.dirname(__file__), "parallel_requester", "ParallelRequester")
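
# Interfaces assumed for the dynamically imported helpers, based only on how they
# are used in this file (not a definitive spec of those modules):
#   Configuration().init_persona()           -> prepares persona data
#   Configuration().run(passage)             -> dict of config fields for one passage
#   QueryGenerator().run(passage, config)    -> dict containing at least a "query" key
#   ParallelRequester().run(items, fn, n)    -> list of fn(item) results using n threads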


class Pipline:
    """End-to-end pipeline: load blog posts, chunk them, generate queries, and save the dataset."""

    def __init__(self):
        self.file_path = os.path.dirname(__file__)
        self.configuration = Configuration()
        self.configuration.init_persona()
        self.query_generator = QueryGenerator()

    def load_data(self):
        df = pd.read_csv(self.file_path + "/../data/persian_blog/blogs.csv")
        # Keep only the first column (the blog text) of each row.
        rows = df.values.tolist()
        rows = [row[0] for row in rows]
        return rows

    def get_new_path(self):
        """Create and return a fresh versioned output folder: data/generated/v1, v2, ..."""
        path = self.file_path + "/../data/generated"
        if not os.path.exists(path):
            os.makedirs(path)

        folders = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]
        pattern = r"^v(\d+)$"
        all_numbers = []
        for f in folders:
            match = re.match(pattern, f)
            if match:
                all_numbers.append(int(match.group(1)))

        number = max(all_numbers) + 1 if all_numbers else 1

        path = os.path.join(path, "v" + str(number))
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    def get_json_path(self, save_path):
        """Return the path of the next part file: part_1_dataset.json, part_2_dataset.json, ..."""
        files = [f for f in os.listdir(save_path) if os.path.isfile(os.path.join(save_path, f))]
        pattern = r"^part_(\d+)_dataset\.json$"
        all_numbers = []
        for f in files:
            match = re.match(pattern, f)
            if match:
                all_numbers.append(int(match.group(1)))

        number = max(all_numbers) + 1 if all_numbers else 1

        json_path = os.path.join(save_path, "part_" + str(number) + "_dataset.json")
        return json_path
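
    # Resulting on-disk layout (derived from get_new_path / get_json_path above):
    #   data/generated/v1/part_1_dataset.json
    #   data/generated/v1/part_2_dataset.json
    #   data/generated/v2/part_1_dataset.json   <- next run without an explicit save_path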

    def save_dataset(self, data, save_path):
        json_path = self.get_json_path(save_path)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def get_a_data(self):
        # Thread-safe cursor over self.data. Note: self.lock, self.data and
        # self.data_idx are not set in __init__ and must be assigned before use.
        with self.lock:
            if self.data_idx < len(self.data):
                data = self.data[self.data_idx]
                data_idx = self.data_idx
            else:
                data = None
                data_idx = None
            self.data_idx += 1
        return data, data_idx

    def exec_function(self, passage):
        """Build one dataset record for a single passage; on failure, record the traceback."""
        try:
            config = self.configuration.run(passage)
            generated_data = self.query_generator.run(passage, config)
            one_data = config.copy()
            one_data["document"] = passage
            one_data["query"] = generated_data["query"]
        except Exception:
            one_data = {"passage": passage, "error": traceback.format_exc()}
        return one_data
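
    # Record schema produced above: on success, all config fields plus
    # {"document": <passage>, "query": <generated query>}; on failure,
    # {"passage": <passage>, "error": <traceback string>}.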

    def make_a_passage(self, selected_length, sentences, start_idx):
        """Join sentences starting at start_idx until the passage would exceed selected_length."""
        one_passage = ""
        for i in range(start_idx, len(sentences)):
            if len(one_passage) + len(sentences[i]) > selected_length and len(one_passage) > 0:
                return one_passage, i
            if one_passage == "":
                one_passage += sentences[i]
            else:
                one_passage += "." + sentences[i]
        return one_passage, len(sentences)

    def chunk_data(self, passage):
        """Split a long passage into sentence-aligned chunks of randomly chosen target lengths."""
        max_length = 3000
        min_length = 30

        if len(passage) < max_length:
            return [passage]

        sentences = passage.split(".")

        all_passages = []
        start_idx = 0
        stop_idx = 0
        while True:
            selected_length = random.choice([50, 100, 200, 300, 500, 800, 1300, 2000, 3000])
            start_idx = stop_idx
            one_passage, stop_idx = self.make_a_passage(selected_length, sentences, start_idx)

            # Drop fragments that are too short to be useful.
            if len(one_passage) > min_length:
                all_passages += [one_passage]

            if stop_idx == len(sentences):
                break

        return all_passages
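
    # Illustrative walk-through (example lengths only, not taken from the data): a
    # 7000-character post is split on "." into sentences; each iteration picks a
    # target length from the list above and accumulates sentences until that length
    # would be exceeded, so the post might come back as chunks of roughly 2000,
    # 3000 and 1300 characters plus a remainder, with any chunk of 30 characters
    # or fewer discarded.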

    def pre_process(self, data):
        chunk_data = []
        for i in tqdm.trange(len(data)):
            chunk_data += self.chunk_data(data[i])
        random.shuffle(chunk_data)
        return chunk_data

    def run_one_part(self, chunk_data, save_path, num_threads):
        parallel_requester = ParallelRequester()
        dataset = parallel_requester.run(chunk_data, self.exec_function, num_threads)
        self.save_dataset(dataset, save_path)

    def run(self, save_path=None):
        data = self.load_data()
        chunk_data = self.pre_process(data)

        num_data = 250000
        num_part_data = 25000
        num_threads = 5

        if save_path is None:
            save_path = self.get_new_path()

        for i in range(0, num_data, num_part_data):
            start_idx = i
            stop_idx = min(i + num_part_data, num_data)
            self.run_one_part(chunk_data[start_idx:stop_idx], save_path, num_threads)
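
    # run() generates up to num_data = 250,000 records in parts of num_part_data = 25,000;
    # each part is written to its own part_N_dataset.json inside save_path. Passing an
    # existing vN folder as save_path appends new part files there (get_json_path continues
    # the numbering); otherwise a fresh vN folder is created.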


def main():
    random.seed(42)
    pipline = Pipline()
    pipline.run()


if __name__ == "__main__":
    main()
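
# Usage sketch based on the paths used above (the script's own filename is not
# fixed by this module): place it next to configuration.py, query_generator.py and
# parallel_requester.py, put the source CSV at ../data/persian_blog/blogs.csv, then
# run the file directly with Python. Generated parts appear under
# ../data/generated/vN/part_M_dataset.json.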