53 lines
2.1 KiB
Python
53 lines
2.1 KiB
Python
from swift.llm import ResponsePreprocessor, DatasetMeta, register_dataset, SubsetDataset, load_dataset
|
|
from typing import Dict, Any
|
|
import os
|
|
|
|
|
|
class CustomPreprocessor(ResponsePreprocessor):
    """Convert raw ``{"query": ..., "document": ...}`` rows into embedding samples.

    Each row's ``query`` is wrapped in an instruction prompt (see
    ``get_detailed_instruct``) and its ``document`` becomes the single positive
    passage. The assembled dict is then passed to
    ``ResponsePreprocessor.preprocess`` for the remaining conversion.
    """

    # Task description prepended to every query by ``add_template``. Subclasses
    # (or instances) may override it to reuse this preprocessor for other tasks.
    TASK_DESCRIPTION = (
        'Given a web search query, retrieve relevant passages that answer the query'
    )

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Return ``query`` formatted as an instruction prompt.

        Note: there is deliberately no space after ``Query:``; the exact
        format string is preserved from the original implementation.
        """
        return f'Instruct: {task_description}\nQuery:{query}'

    def add_template(self, text: str) -> str:
        """Wrap ``text`` with the class-level retrieval task instruction."""
        return self.get_detailed_instruct(self.TASK_DESCRIPTION, text)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Build a query / positive / (optional) negative sample from ``row``.

        Args:
            row: Raw record; must contain ``"query"`` and ``"document"`` keys
                (a missing key raises ``KeyError``).

        Returns:
            Whatever ``ResponsePreprocessor.preprocess`` returns for the dict
            with ``query``, ``positive_messages`` and, only when at least one
            negative exists, ``negative_messages``.
        """
        query = self.add_template(row["query"])
        passage_positive = [row["document"]]

        # Placeholders: no hard negatives or random negatives are mined yet.
        passage_negative = []
        passage_negative_random = []

        # Deduplicate while keeping first-seen order. ``list(set(...))`` would
        # yield a hash-seed-dependent order and make runs non-reproducible.
        all_neg = list(dict.fromkeys(passage_negative + passage_negative_random))

        sample = {
            'query': query,
            'positive_messages': [
                [{'role': 'user', 'content': passage}] for passage in passage_positive
            ],
            # 'label': 1.0
        }
        # Only emit the key when there is something in it (the original built
        # the key unconditionally and then deleted it when empty).
        if all_neg:
            sample['negative_messages'] = [
                [{'role': 'user', 'content': passage}] for passage in all_neg
            ]

        return super().preprocess(sample)
|
|
|
|
|
|
# Register the local JSON file so it can be loaded by name via ``load_dataset``.
# The path is anchored at this module's absolute directory: with plain string
# concatenation, an empty ``os.path.dirname(__file__)`` would turn the path
# into ``/v11_dataset.json`` at the filesystem root.
register_dataset(
    DatasetMeta(
        dataset_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'v11_dataset.json'),
        dataset_name="v11_generated_dataset",
        # NOTE(review): explicit train/test subsets are disabled; original line kept:
        # subsets=[SubsetDataset('train', split=['train']), SubsetDataset('test', split=['test'])],
        preprocess_func=CustomPreprocessor(),
    ))
|
|
|
|
if __name__ == '__main__':
    # Smoke test: load the dataset registered in this module and print one
    # processed sample.
    # load_dataset returns train_dataset and val_dataset based on `split_dataset_ratio`.
    # Since `split_dataset_ratio` is not passed here (defaults to 0), take index 0.
    # The previous extra load of 'swift/financial_classification:test' was a
    # leftover from the upstream example: its 'test' subset registration is no
    # longer in this file, and it is unrelated to the dataset defined here.
    dataset = load_dataset('v11_generated_dataset')[0]
    print(f'dataset[0]: {dataset[0]}')