2025-12-21 12:09:32 +00:00

57 lines
2.3 KiB
Python

from swift.llm import ResponsePreprocessor, DatasetMeta, register_dataset, SubsetDataset, load_dataset
from typing import Dict, Any
import os
class CustomPreprocessor(ResponsePreprocessor):
    """Convert a retrieval row into swift's query / positive / negative message format.

    Each input row is read for ``query``, ``passage_positive``,
    ``passage_negative`` and ``passage_negative_random``. The query is wrapped
    in an instruction prompt; every positive passage becomes one
    ``positive_messages`` entry, and the de-duplicated union of both negative
    lists becomes ``negative_messages`` (omitted when there are no negatives).
    The resulting dict is handed to ``ResponsePreprocessor.preprocess``.
    """

    # Instruction used by ``add_template`` when no explicit task is given.
    DEFAULT_TASK = 'Given a web search query, retrieve relevant passages that answer the query'

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Return ``query`` prefixed with an ``Instruct: <task>`` line.

        The exact layout (including no space after ``Query:``) is kept as-is:
        changing it changes the prompt text the model sees.
        """
        return f'Instruct: {task_description}\nQuery:{query}'

    def add_template(self, text: str, task=None) -> str:
        """Wrap ``text`` in the instruction prompt.

        Args:
            text: The raw query text.
            task: Optional task description; defaults to ``DEFAULT_TASK``.

        Returns:
            The instruction-prefixed query string.
        """
        return self.get_detailed_instruct(self.DEFAULT_TASK if task is None else task, text)

    @staticmethod
    def _as_list(value: Any) -> list:
        """Normalize a passage field: ``None`` -> ``[]``, a lone ``str`` -> ``[str]``, else ``list(value)``."""
        if value is None:
            return []
        if isinstance(value, str):
            # A bare string would otherwise be iterated character by character.
            return [value]
        return list(value)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Build the message-format sample for one row and delegate to the parent.

        Args:
            row: Mapping with ``query`` and ``passage_positive`` (required) and
                optionally ``passage_negative`` / ``passage_negative_random``.

        Returns:
            Whatever ``ResponsePreprocessor.preprocess`` returns for the built sample.
        """
        query = self.add_template(row['query'])
        positives = self._as_list(row['passage_positive'])
        # dict.fromkeys de-duplicates while keeping first-seen order. The old
        # list(set(...)) ordered strings by hash, which is randomized per
        # process (PYTHONHASHSEED), so negative order differed between runs.
        negatives = list(dict.fromkeys(
            self._as_list(row.get('passage_negative'))
            + self._as_list(row.get('passage_negative_random'))
        ))

        sample = {
            'query': query,
            'positive_messages': [[{'role': 'user', 'content': passage}] for passage in positives],
        }
        # Only include negative_messages when there is at least one negative,
        # matching the original behaviour of dropping an empty list.
        if negatives:
            sample['negative_messages'] = [[{'role': 'user', 'content': passage}] for passage in negatives]
        return super().preprocess(sample)
# Register the local hard-negative JSON stored next to this file under the
# name 'v11_dataset_hn', preprocessed by CustomPreprocessor, so it can be
# loaded with load_dataset('v11_dataset_hn').
register_dataset(
    DatasetMeta(
        # os.path.join on an absolute dirname: the previous
        # `dirname(__file__) + '/...'` produced '/v11_dataset_hn.json'
        # (filesystem root) whenever dirname() returned ''.
        dataset_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'v11_dataset_hn.json'),
        dataset_name='v11_dataset_hn',
        preprocess_func=CustomPreprocessor(),
    ))
if __name__ == '__main__':
    # Smoke test: load the registered dataset plus a reference dataset and
    # print the first processed row of each.
    # load_dataset returns a (train, val) pair split by `split_dataset_ratio`;
    # it is not passed here, so (per the original note) index 0 is the whole set.
    train_split = load_dataset('v11_dataset_hn')[0]
    # NOTE(review): this reference dataset is unrelated to v11_dataset_hn and is
    # presumably fetched from the hub; looks like a leftover from swift's
    # custom-dataset example — confirm it is still wanted.
    reference_split = load_dataset('swift/financial_classification:test')[0]
    for label, split in (('dataset', train_split), ('test_dataset', reference_split)):
        print(f'{label}[0]: {split[0]}')