24 lines
680 B
Python
24 lines
680 B
Python
import torch
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Prefer the first CUDA GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Model to load: looks like a saved training checkpoint (path or hub id).
# Alternative value kept for quick switching, presumably the base model the
# checkpoint was trained from (confirm):
# model_id = "google/embeddinggemma-300M"
model_id = "my-embedding-gemma/checkpoint-15"

# Load the embedding model and place it on the selected device.
model = SentenceTransformer(model_id, device=device)
|
|
|
|
|
|
def get_scores(query, document, *, encoder=None):
    """Embed a query and a document, print and return their similarity scores.

    Printing is kept so existing script usage is unchanged. The scores are
    now also returned so callers can use them programmatically.

    Args:
        query: Query text, passed to ``encoder.encode_query``.
        document: Document text, passed to ``encoder.encode_document``.
        encoder: Object providing ``encode_query``, ``encode_document`` and
            ``similarity``. Defaults to the module-level ``model``.

    Returns:
        Whatever ``encoder.similarity`` returns. For a SentenceTransformer
        this is presumably a (num_queries, num_documents) score tensor;
        confirm against the library version in use.
    """
    if encoder is None:
        encoder = model

    # Queries and documents use separate encode methods; presumably these
    # apply different prompts for asymmetric retrieval. Confirm in the
    # model's config.
    query_embedding = encoder.encode_query(query)
    doc_embedding = encoder.encode_document(document)

    # Calculate the embedding similarities
    similarities = encoder.similarity(query_embedding, doc_embedding)

    print(similarities)
    return similarities
|
|
|
|
if __name__ == "__main__":
    # Demo: score a single example query against a single document.
    # Guarded so importing this module does not run inference or print.
    query = "I want to start a tax-free installment investment, what should I do?"
    documents = "Opening a NISA Account"

    get_scores(query, documents)
|