embedding_model/train/gemma/gemma_inference.py
2025-11-11 15:02:47 +00:00

24 lines
680 B
Python

import torch
from sentence_transformers import SentenceTransformer
# Use the first CUDA GPU when torch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Identifier handed to SentenceTransformer below. The stock model id is kept
# as a commented-out alternative for quick switching.
# model_id = "google/embeddinggemma-300M"
model_id = "my-embedding-gemma/checkpoint-15"

# Load the model, then move it onto the selected device.
model = SentenceTransformer(model_id)
model = model.to(device=device)
def get_scores(query, document):
    """Print and return the similarity between a query and a document.

    The query is embedded with ``model.encode_query`` and the document with
    ``model.encode_document`` (``model`` is the module-level
    SentenceTransformer); the two embeddings are then compared with
    ``model.similarity``.

    Args:
        query: Query input forwarded to ``model.encode_query``.
        document: Document input forwarded to ``model.encode_document``.

    Returns:
        The value produced by ``model.similarity`` for the two embeddings.
        The same value is also printed to stdout.
    """
    query_embedding = model.encode_query(query)
    doc_embedding = model.encode_document(document)

    # Calculate the embedding similarities
    similarities = model.similarity(query_embedding, doc_embedding)
    print(similarities)

    # Return the scores too, so callers can consume them programmatically
    # instead of only seeing the printed output.
    return similarities
# Example inputs: one query and one candidate document title.
query = "I want to start a tax-free installment investment, what should I do?"
documents = "Opening a NISA Account"

# Score the example pair; get_scores prints the resulting similarities.
get_scores(query=query, document=documents)