from swift.llm import ResponsePreprocessor, DatasetMeta, register_dataset, SubsetDataset, load_dataset
from typing import Dict, Any
import os
class CustomPreprocessor(ResponsePreprocessor):
    """Convert raw retrieval rows into swift's embedding-training format.

    Each input row is expected to carry a 'query' string plus three lists of
    passage strings: 'passage_positive', 'passage_negative' and
    'passage_negative_random'.  The two negative pools are merged and
    deduplicated before being emitted as 'negative_messages'.
    """

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Wrap *query* in the instruct template used by the embedding model.

        NOTE: there is intentionally no space after 'Query:' — this matches
        the exact template string the model family documents; do not "fix"
        the formatting.
        """
        return f'Instruct: {task_description}\nQuery:{query}'

    def add_template(self, text: str) -> str:
        """Apply the fixed web-search retrieval instruction to *text*."""
        task = 'Given a web search query, retrieve relevant passages that answer the query'
        return self.get_detailed_instruct(task, text)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Build the query/positive/negative message structure for one row.

        Negatives from 'passage_negative' and 'passage_negative_random' are
        concatenated and deduplicated.  Deduplication uses dict.fromkeys so
        the surviving order is deterministic and reproducible across runs —
        list(set(...)) would reorder items under hash randomization.
        """
        query = self.add_template(row["query"])
        passage_positive = row["passage_positive"]

        # Merge both negative pools, dropping duplicates while keeping the
        # first occurrence's position.
        all_neg = list(dict.fromkeys(row["passage_negative"] + row["passage_negative_random"]))

        row = {
            'query': query,
            'positive_messages': [
                [{'role': 'user', 'content': passage}] for passage in passage_positive
            ],
            'negative_messages': [
                [{'role': 'user', 'content': passage}] for passage in all_neg
            ],
        }
        # Drop the key entirely (rather than keep an empty list) when there
        # are no negatives — presumably downstream treats a missing key
        # differently from []; this mirrors the original behavior.
        if not row["negative_messages"]:
            del row["negative_messages"]
        return super().preprocess(row)
_dataset_path = os.path.dirname(__file__) + '/generated_250000_general.jsonl'
register_dataset(
    DatasetMeta(
        dataset_path=_dataset_path,
        dataset_name="generated_250000_general",
        preprocess_func=CustomPreprocessor(),
    ))
if __name__ == '__main__':
    # load_dataset returns (train_dataset, val_dataset) split by
    # `split_dataset_ratio`; with the default ratio of 0 everything lands in
    # the first element, so index 0 is the full dataset.
    train_split = load_dataset('generated_250000_general')[0]
    eval_split = load_dataset('swift/financial_classification:test')[0]
    print(f'dataset[0]: {train_split[0]}')
    print(f'test_dataset[0]: {eval_split[0]}')