use BAAI/bge-base-en-v1.5
Browse files
app.py
CHANGED
|
@@ -5,8 +5,8 @@ import torch
|
|
| 5 |
from transformers import AutoModel, AutoTokenizer
|
| 6 |
import meilisearch
|
| 7 |
|
| 8 |
-
tokenizer = AutoTokenizer.from_pretrained('
|
| 9 |
-
model = AutoModel.from_pretrained('
|
| 10 |
model.eval()
|
| 11 |
|
| 12 |
cuda_available = torch.cuda.is_available()
|
|
@@ -23,16 +23,17 @@ def search_embeddings(query_text):
|
|
| 23 |
# step1: tokenizer the query
|
| 24 |
with torch.no_grad():
|
| 25 |
# Compute token embeddings
|
| 26 |
-
|
|
|
|
| 27 |
# normalize embeddings
|
| 28 |
-
|
| 29 |
-
|
| 30 |
elapsed_time_embedding = time.time() - start_time_embedding
|
| 31 |
|
| 32 |
# step2: search meilisearch
|
| 33 |
start_time_meilisearch = time.time()
|
| 34 |
response = meilisearch_index.search(
|
| 35 |
-
"", opt_params={"vector":
|
| 36 |
)
|
| 37 |
elapsed_time_meilisearch = time.time() - start_time_meilisearch
|
| 38 |
hits = response["hits"]
|
|
|
|
| 5 |
from transformers import AutoModel, AutoTokenizer
|
| 6 |
import meilisearch
|
| 7 |
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-base-en-v1.5')
|
| 9 |
+
model = AutoModel.from_pretrained('BAAI/bge-base-en-v1.5')
|
| 10 |
model.eval()
|
| 11 |
|
| 12 |
cuda_available = torch.cuda.is_available()
|
|
|
|
| 23 |
# step1: tokenizer the query
|
| 24 |
with torch.no_grad():
|
| 25 |
# Compute token embeddings
|
| 26 |
+
model_output = model(**query_tokens)
|
| 27 |
+
sentence_embeddings = model_output[0][:, 0]
|
| 28 |
# normalize embeddings
|
| 29 |
+
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
|
| 30 |
+
sentence_embeddings_list = sentence_embeddings[0].tolist()
|
| 31 |
elapsed_time_embedding = time.time() - start_time_embedding
|
| 32 |
|
| 33 |
# step2: search meilisearch
|
| 34 |
start_time_meilisearch = time.time()
|
| 35 |
response = meilisearch_index.search(
|
| 36 |
+
"", opt_params={"vector": sentence_embeddings_list, "hybrid": {"semanticRatio": 1.0}, "limit": 5, "attributesToRetrieve": ["text", "source", "library"]}
|
| 37 |
)
|
| 38 |
elapsed_time_meilisearch = time.time() - start_time_meilisearch
|
| 39 |
hits = response["hits"]
|