"""Register a local query/passage retrieval dataset with ms-swift.

Importing this module registers the dataset ``generated_250000_general``,
backed by ``generated_250000_general.jsonl`` located next to this file.
Running it as a script additionally loads the dataset and prints one sample.
"""
import os
from typing import Any, Dict

from swift.llm import ResponsePreprocessor, DatasetMeta, register_dataset, SubsetDataset, load_dataset

# Directory containing this file. Made absolute so the dataset path stays
# valid even when ``__file__`` carries no directory component.
_DATASET_DIR = os.path.dirname(os.path.abspath(__file__))


class CustomPreprocessor(ResponsePreprocessor):
    """Turn raw query/passage rows into query + positive/negative message lists."""

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Return the instruction-prefixed query string.

        Args:
            task_description: Text placed after ``Instruct: ``.
            query: Text placed after ``Query:``.
        """
        return f'Instruct: {task_description}\nQuery:{query}'

    def add_template(self, text: str) -> str:
        """Wrap ``text`` with the fixed web-search retrieval instruction."""
        task = 'Given a web search query, retrieve relevant passages that answer the query'
        return self.get_detailed_instruct(task, text)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Build the retrieval sample for one row and delegate to the parent class.

        Args:
            row: Must contain ``query`` plus the sequences ``passage_positive``,
                ``passage_negative`` and ``passage_negative_random``
                (presumably lists of passage strings; confirm against the
                jsonl schema).

        Returns:
            Whatever ``ResponsePreprocessor.preprocess`` returns for a dict with
            the keys ``query``, ``positive_messages`` and, only when at least
            one negative passage exists, ``negative_messages``.

        Raises:
            KeyError: If one of the required keys is missing from ``row``.
        """
        query = self.add_template(row["query"])
        positives = row["passage_positive"]
        negatives = row["passage_negative"] + row["passage_negative_random"]
        # De-duplicate while keeping first-seen order. ``set`` would also
        # de-duplicate, but its iteration order for strings changes between
        # interpreter runs (hash randomization), making the generated
        # samples non-reproducible.
        unique_negatives = list(dict.fromkeys(negatives))

        sample: Dict[str, Any] = {
            'query': query,
            'positive_messages': [
                [{'role': 'user', 'content': passage}] for passage in positives
            ],
        }
        # The key is omitted entirely (rather than set to an empty list) when
        # there are no negatives.
        if unique_negatives:
            sample['negative_messages'] = [
                [{'role': 'user', 'content': passage}] for passage in unique_negatives
            ]
        return super().preprocess(sample)


register_dataset(
    DatasetMeta(
        dataset_path=os.path.join(_DATASET_DIR, 'generated_250000_general.jsonl'),
        dataset_name="generated_250000_general",
        preprocess_func=CustomPreprocessor(),
    ))

if __name__ == '__main__':
    # load_dataset returns train_dataset and val_dataset based on `split_dataset_ratio`
    # Here, since we didn't pass `split_dataset_ratio` (defaults to 0), we take the first one (index 0)
    dataset = load_dataset('generated_250000_general')[0]
    test_dataset = load_dataset('swift/financial_classification:test')[0]
    print(f'dataset[0]: {dataset[0]}')
    print(f'test_dataset[0]: {test_dataset[0]}')