Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -135,6 +135,8 @@ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
|
|
| 135 |
doc_list = [f for f in os.listdir(target_cache_dir) if f.endswith('.npy')]
|
| 136 |
doc_list = sorted(doc_list)
|
| 137 |
doc_reps = [np.load(os.path.join(target_cache_dir, f)) for f in doc_list]
|
|
|
|
|
|
|
| 138 |
|
| 139 |
query_with_instruction = "Represent this query for retrieving relevant document: " + query
|
| 140 |
with torch.no_grad():
|
|
@@ -142,7 +144,6 @@ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
|
|
| 142 |
|
| 143 |
query_md5 = hashlib.md5(query.encode()).hexdigest()
|
| 144 |
|
| 145 |
-
doc_reps_cat = torch.cat([torch.Tensor(i) for i in doc_reps], dim=0)
|
| 146 |
print(f"query_rep_shape: {query_rep.shape}, doc_reps_cat_shape: {doc_reps_cat.shape}")
|
| 147 |
similarities = torch.matmul(query_rep, doc_reps_cat.T)
|
| 148 |
|
|
|
|
| 135 |
doc_list = [f for f in os.listdir(target_cache_dir) if f.endswith('.npy')]
|
| 136 |
doc_list = sorted(doc_list)
|
| 137 |
doc_reps = [np.load(os.path.join(target_cache_dir, f)) for f in doc_list]
|
| 138 |
+
doc_reps_cat = torch.cat([torch.Tensor(i) for i in doc_reps], dim=0)
|
| 139 |
+
doc_reps_cat = torch.cat([i for i in doc_reps_cat], dim=0)
|
| 140 |
|
| 141 |
query_with_instruction = "Represent this query for retrieving relevant document: " + query
|
| 142 |
with torch.no_grad():
|
|
|
|
| 144 |
|
| 145 |
query_md5 = hashlib.md5(query.encode()).hexdigest()
|
| 146 |
|
|
|
|
| 147 |
print(f"query_rep_shape: {query_rep.shape}, doc_reps_cat_shape: {doc_reps_cat.shape}")
|
| 148 |
similarities = torch.matmul(query_rep, doc_reps_cat.T)
|
| 149 |
|