From 59831896f1f1f44b51c438f43ee6e64859e3a419 Mon Sep 17 00:00:00 2001
From: "a.hediehloo"
Date: Sat, 27 Dec 2025 06:49:38 +0000
Subject: [PATCH] evaluate_with_religous_50000

---
 .gitignore                                     |   2 +
 .../convert_to_jsonl.py                        |  14 ++
 .../generated_250000_religous_hn/generated.py  |  61 ++++++++
 evaluation/evaluate_with_religous_50000.py     | 147 ++++++++++++++++++
 train/qwen/a.sh                                |   4 +-
 train/qwen/merge_model.py                      |   4 +-
 train/qwen/slerp_merge.py                      |   1 -
 7 files changed, 228 insertions(+), 5 deletions(-)
 create mode 100644 data/dataset/generated_250000_religous_hn/convert_to_jsonl.py
 create mode 100644 data/dataset/generated_250000_religous_hn/generated.py
 create mode 100644 evaluation/evaluate_with_religous_50000.py

diff --git a/.gitignore b/.gitignore
index 9fe8795..9b49d6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,5 @@ data/dataset/generated_250000_religous/__pycache__
 data/dataset/my_local_dataset/__pycache__
 data/dataset/v11_dataset_hn/__pycache__
 data/dataset/v11_generated/__pycache__
+data/dataset/generated_250000_religous_hn/__pycache__
+data/dataset/generated_250000_religous_hn/250_religous_hn.jsonl
diff --git a/data/dataset/generated_250000_religous_hn/convert_to_jsonl.py b/data/dataset/generated_250000_religous_hn/convert_to_jsonl.py
new file mode 100644
index 0000000..e72dc87
--- /dev/null
+++ b/data/dataset/generated_250000_religous_hn/convert_to_jsonl.py
@@ -0,0 +1,14 @@
+import json
+import os
+
+file_path = os.path.dirname(__file__)
+input_file = file_path + "/250_religous_hn.json"
+output_file = file_path + "/250_religous_hn.jsonl"
+
+with open(input_file, "r", encoding="utf-8") as f_in, open(output_file, "w", encoding="utf-8") as f_out:
+    data = json.load(f_in)  # list of records
+    for record in data:
+        json_line = json.dumps(record, ensure_ascii=False)
+        f_out.write(json_line + "\n")
+
+print(f"Converted {input_file} to {output_file}")
\ No newline at end of file
diff --git a/data/dataset/generated_250000_religous_hn/generated.py b/data/dataset/generated_250000_religous_hn/generated.py
new file mode 100644
index 0000000..57d8f19
--- /dev/null
+++ b/data/dataset/generated_250000_religous_hn/generated.py
@@ -0,0 +1,61 @@
+from swift.llm import ResponsePreprocessor, DatasetMeta, register_dataset, SubsetDataset, load_dataset
+from typing import Dict, Any
+import os
+
+
+class CustomPreprocessor(ResponsePreprocessor):
+    # def __init__(self, *, columns = None, **kwargs):
+    #     super().__init__(columns=columns, **kwargs)
+    #     self.num_all_negative = 0
+    def get_detailed_instruct(self, task_description: str, query: str) -> str:
+        return f'Instruct: {task_description}\nQuery:{query}'
+
+    def add_template(self, text):
+        task = 'Given a web search query, retrieve relevant passages that answer the query'
+        return self.get_detailed_instruct(task, text)
+
+    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
+        query = self.add_template(row["query"])
+        passage_positive = row["passage_positive"]
+        passage_negative = row["passage_negative"]
+        passage_negative_random = row["passage_negative_random"]
+        passage_negative_random_all = row["passage_negative_random_all"]
+
+        # all_neg = passage_negative + passage_negative_random + passage_negative_random_all
+        all_neg = passage_negative_random
+        all_neg = list(set(all_neg))
+        # self.num_all_negative += len(all_neg)
+
+        row = {
+            # 'query': [{'role': 'user', 'content': query, 'loss': None}],
+            'query': query,
+            'positive_messages': [
+                [{'role': 'user', 'content': passage_positive[i]}] for i in range(len(passage_positive))
+            ],
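+            # Positives are wrapped as single-message chat lists; the hard
+            # negatives in 'negative_messages' below use the same wrapping.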
+            'negative_messages': [
+                [{'role': 'user', 'content': all_neg[i]}] for i in range(len(all_neg))
+            ],
+            # 'label': 1.0
+        }
+        if len(row["negative_messages"]) == 0:
+            del row["negative_messages"]
+        return super().preprocess(row)
+
+
+register_dataset(
+    DatasetMeta(
+        dataset_path=os.path.dirname(__file__) + '/250_religous_hn.jsonl',
+        dataset_name="generated_250000_religous_hn",
+        # subsets=[SubsetDataset('train', split=['train']), SubsetDataset('test', split=['test'])],
+        preprocess_func=CustomPreprocessor(),
+    ))
+
+if __name__ == '__main__':
+    # load_dataset returns train_dataset and val_dataset based on `split_dataset_ratio`.
+    # Since `split_dataset_ratio` was not passed (it defaults to 0), we take the first one (index 0).
+    dataset = load_dataset('generated_250000_religous_hn')[0]
+    test_dataset = load_dataset('swift/financial_classification:test')[0]
+    print(f'dataset[0]: {dataset[0]}')
+    print(f'test_dataset[0]: {test_dataset[0]}')
\ No newline at end of file
diff --git a/evaluation/evaluate_with_religous_50000.py b/evaluation/evaluate_with_religous_50000.py
new file mode 100644
index 0000000..5451609
--- /dev/null
+++ b/evaluation/evaluate_with_religous_50000.py
@@ -0,0 +1,147 @@
+import mteb
+import numpy as np
+import requests
+import tqdm
+from torch.utils.data import DataLoader
+from mteb.encoder_interface import PromptType
+from typing import Any
+# from mteb.abstasks.task_metadata import TaskMetadata
+# from mteb.models.models_protocols import EncoderProtocol
+import json
+import os
+from sentence_transformers import SentenceTransformer
+from datasets import load_dataset
+from datasets.config import HF_DATASETS_CACHE
+from huggingface_hub.utils import get_session
+import faiss
+
+
+class CustomModel:
+    def __init__(self, model):
+        self.session = requests.Session()
+        self.model = model
+
+    def get_simplexity_query2vec_results(self, sentences, embedding_url, model, template):
+        params = {}
+        params["model"] = model
+        params["template"] = template
+        headers = {"accept": "application/json"}
+        data = {}
+
+        # Show a progress bar only for large inputs.
+        if len(sentences) < 2000:
+            my_range = range
+        else:
+            my_range = tqdm.trange
+
+        batch_size = 1024
+        vec = []
+        for i in my_range(0, len(sentences), batch_size):
+            start_idx = i
+            stop_idx = min(i + batch_size, len(sentences))
+            data["queries"] = sentences[start_idx:stop_idx]
+            response = self.session.post(embedding_url, headers=headers, params=params, data=json.dumps(data), timeout=600)
+            new_vec = response.json()
+            vec += new_vec
+        return vec
+
+    def encode(
+        self,
+        sentences: list[str],
+        task_name: str,
+        prompt_type: PromptType | None = None,
+        **kwargs,
+    ) -> np.ndarray:
+
+        embedding_url = "http://127.0.0.1:5015/embedding"
+
+        if prompt_type is None:
+            template = "document"
+        elif prompt_type == PromptType.query:
+            template = "query"
+        elif prompt_type == PromptType.document:
+            template = "document"
+        else:
+            raise ValueError(f"Unexpected prompt_type: {prompt_type}")
+
+        # all_texts = []
+        # for batch in inputs:
+        #     all_texts += batch["text"]
+        #     embeddings = self.get_simplexity_query2vec_results(batch["text"], embedding_url, model, template)
+        #     all_embeddings += embeddings
+        all_embeddings = self.get_simplexity_query2vec_results(sentences, embedding_url, self.model, template)
+
+        # faiss expects float32; JSON floats decode to float64 by default.
+        return np.asarray(all_embeddings, dtype=np.float32)
+
+
+def build_faiss_index(embeddings: np.ndarray) -> faiss.Index:
+    dim = embeddings.shape[1]
+    index = faiss.IndexFlatIP(dim)  # inner product equals cosine similarity for L2-normalized vectors
+    index.add(embeddings)
+    return index
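+
+# Usage sketch (hypothetical, not executed by this script): assumes the
+# embedding service above is running and returns L2-normalized vectors,
+# so inner product equals cosine similarity.
+#   emb = CustomModel("Qwen3-Embedding-0.6B").encode(["a", "b"], None, prompt_type=PromptType.document)
+#   idx = build_faiss_index(emb)
+#   scores, ids = idx.search(emb[:1], 2)  # ids[0][0] == 0: a document is its own nearest neighbor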
+
+
+def recall_at_k(index, query_embeddings, ground_truth, all_docs, k=10):
+    """
+    ground_truth: the gold positive passage text for each query; a query counts
+    as a hit if that passage appears among its top-k retrieved documents.
+    """
+    scores, retrieved = index.search(query_embeddings, k)
+
+    hits = 0
+    for i, gt in enumerate(ground_truth):
+        if gt in [all_docs[j] for j in retrieved[i]]:
+            hits += 1
+
+    return hits / len(ground_truth)
+
+
+def evaluate():
+    model_name = "Qwen3-Embedding-0.6B"
+    # model_name = "KaLM-embedding-multilingual-mini-instruct-v2.5"
+    # model_name = "KaLM-Embedding-Gemma3-12B-2511"
+    # model_name = "llama-embed-nemotron-8b"
+    # model_name = "embeddinggemma-300m"
+    model = CustomModel(model_name)
+
+    file_path = os.path.dirname(__file__)
+
+    with open(file_path + "/../data/dataset/religous_test_50000/test_religous.json", "r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    all_docs = [data[i]["passage_positive"][j] for i in range(len(data)) for j in range(len(data[i]["passage_positive"]))]
+    all_queries = [data[i]["question"] for i in range(len(data))]
+
+    num_docs = len(all_docs)
+    num_test = len(all_queries)
+
+    doc_embedding = model.encode(all_docs[0:num_docs], None, prompt_type=PromptType.document)
+    index = build_faiss_index(doc_embedding)
+
+    query_embeddings = model.encode(all_queries[0:num_test], None, prompt_type=PromptType.query)
+    ground_truth = [data[i]["passage_positive"][0] for i in range(num_test)]
+    for k in [1, 5, 10]:
+        r = recall_at_k(index, query_embeddings, ground_truth, all_docs, k)
+        print(f"Recall@{k}: {r:.6f}")
+
+    # finetune:
+    # Recall@1: 0.401740
+    # Recall@5: 0.614240
+    # Recall@10: 0.683000
+
+
+def main():
+    # get_results()
+    evaluate()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/train/qwen/a.sh b/train/qwen/a.sh
index 2b3c156..fa149a8 100644
--- a/train/qwen/a.sh
+++ b/train/qwen/a.sh
@@ -19,8 +19,8 @@ swift sft \
     --lora_alpha 32 \
     --target_modules all-linear \
     --max_length 2048 \
-    --dataset generated_250000_religous \
-    --custom_register_path $(pwd)/../../data/dataset/generated_250000_religous/generated.py \
+    --dataset generated_250000_religous_hn \
+    --custom_register_path $(pwd)/../../data/dataset/generated_250000_religous_hn/generated.py \
     --split_dataset_ratio 0.005 \
     --eval_strategy steps \
     --output_dir output \
diff --git a/train/qwen/merge_model.py b/train/qwen/merge_model.py
index b0ef7cc..9a62d2d 100644
--- a/train/qwen/merge_model.py
+++ b/train/qwen/merge_model.py
@@ -33,8 +33,8 @@ def main():
     file_path = os.path.dirname(__file__)
 
     base_model_path = file_path + "/../../data/models/Qwen3-Embedding-0.6B/model"
-    peft_model_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"
-    save_path = file_path + "/output/v23-20251214-111804/merged_checkpoint-slerp"
+    peft_model_path = file_path + "/output/v28-20251223-054407/checkpoint-3707"
+    save_path = file_path + "/output/v28-20251223-054407/merged-checkpoint-3707"
 
     merge(base_model_path, peft_model_path, save_path)
     items = ["1_Pooling", "config_sentence_transformers.json", "merges.txt", "modules.json", "README.md", "tokenizer_config.json", "tokenizer.json",
diff --git a/train/qwen/slerp_merge.py b/train/qwen/slerp_merge.py
index 987a17c..d23cee2 100644
--- a/train/qwen/slerp_merge.py
+++ b/train/qwen/slerp_merge.py
@@ -94,7 +94,6 @@ def main():
     peft_model_path = [
         file_path + "/output/v23-20251214-111804/checkpoint-3632",
         file_path + "/output/v23-20251214-111804/checkpoint-3000",
-        file_path + "/output/v23-20251214-111804/checkpoint-2000",
     ]
     save_path = file_path + "/output/v23-20251214-111804/slerp-checkpoint"