Speed up score calculation
Browse files- src/utils/paper_client.py +15 -0
- src/utils/paper_retriever.py +9 -8
src/utils/paper_client.py
CHANGED
|
@@ -79,6 +79,21 @@ class PaperClient:
|
|
| 79 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
| 80 |
return None
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
| 83 |
query = f"""
|
| 84 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
|
|
|
| 79 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
| 80 |
return None
|
| 81 |
|
| 82 |
+
def get_papers_attribute(self, paper_id_list, attribute_name):
    """Fetch one attribute for every paper in ``paper_id_list`` in a single query.

    Batches the lookup with ``UNWIND`` so only one round trip to Neo4j is
    made, instead of one query per paper id.

    Args:
        paper_id_list: list of ``hash_id`` values identifying the papers.
        attribute_name: name of the node property to read (dynamic property
            access ``p[$attribute_name]``).

    Returns:
        A list of attribute values aligned index-for-index with
        ``paper_id_list``; entries are ``None`` for ids with no matching
        ``Paper`` node.
    """
    # OPTIONAL MATCH keeps the row count equal to len(paper_id_list): a plain
    # MATCH would silently drop missing ids, shortening the result and
    # misaligning callers that zip the returned values against
    # paper_id_list by position (e.g. the retriever's score dict).
    query = """
    UNWIND $paper_ids AS paper_id
    OPTIONAL MATCH (p:Paper {hash_id: paper_id})
    RETURN paper_id AS hash_id, p[$attribute_name] AS attributeValue
    """
    with self.driver.session() as session:
        result = session.execute_read(
            lambda tx: tx.run(
                query, paper_ids=paper_id_list, attribute_name=attribute_name
            ).data()
        )
    return [record["attributeValue"] for record in result]
|
| 96 |
+
|
| 97 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
| 98 |
query = f"""
|
| 99 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
src/utils/paper_retriever.py
CHANGED
|
@@ -184,12 +184,11 @@ class Retriever(object):
|
|
| 184 |
self, embedding, related_paper_id_list, type_name="embedding"
|
| 185 |
):
|
| 186 |
score_1 = np.zeros((len(related_paper_id_list)))
|
| 187 |
-
score_2 = np.zeros((len(related_paper_id_list)))
|
| 188 |
origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
|
| 189 |
-
context_embeddings =
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
]
|
| 193 |
if len(context_embeddings) > 0:
|
| 194 |
context_embeddings = torch.tensor(context_embeddings).to(self.device)
|
| 195 |
score_1 = torch.nn.functional.cosine_similarity(
|
|
@@ -198,8 +197,9 @@ class Retriever(object):
|
|
| 198 |
score_1 = score_1.cpu().numpy()
|
| 199 |
if self.config.RETRIEVE.need_normalize:
|
| 200 |
score_1 = score_1 / np.max(score_1)
|
| 201 |
-
|
| 202 |
-
score_en_dict = dict(zip(related_paper_id_list, score_2))
|
|
|
|
| 203 |
score_all_dict = dict(
|
| 204 |
zip(
|
| 205 |
related_paper_id_list,
|
|
@@ -207,7 +207,8 @@ class Retriever(object):
|
|
| 207 |
+ score_2 * self.config.RETRIEVE.beta,
|
| 208 |
)
|
| 209 |
)
|
| 210 |
-
|
|
|
|
| 211 |
|
| 212 |
def filter_related_paper(self, score_dict, top_k):
|
| 213 |
if len(score_dict) <= top_k:
|
|
|
|
| 184 |
self, embedding, related_paper_id_list, type_name="embedding"
|
| 185 |
):
|
| 186 |
score_1 = np.zeros((len(related_paper_id_list)))
|
| 187 |
+
# score_2 = np.zeros((len(related_paper_id_list)))
|
| 188 |
origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
|
| 189 |
+
context_embeddings = self.paper_client.get_papers_attribute(
|
| 190 |
+
related_paper_id_list, type_name
|
| 191 |
+
)
|
|
|
|
| 192 |
if len(context_embeddings) > 0:
|
| 193 |
context_embeddings = torch.tensor(context_embeddings).to(self.device)
|
| 194 |
score_1 = torch.nn.functional.cosine_similarity(
|
|
|
|
| 197 |
score_1 = score_1.cpu().numpy()
|
| 198 |
if self.config.RETRIEVE.need_normalize:
|
| 199 |
score_1 = score_1 / np.max(score_1)
|
| 200 |
+
score_all_dict = dict(zip(related_paper_id_list, score_1))
|
| 201 |
+
# score_en_dict = dict(zip(related_paper_id_list, score_2))
|
| 202 |
+
"""
|
| 203 |
score_all_dict = dict(
|
| 204 |
zip(
|
| 205 |
related_paper_id_list,
|
|
|
|
| 207 |
+ score_2 * self.config.RETRIEVE.beta,
|
| 208 |
)
|
| 209 |
)
|
| 210 |
+
"""
|
| 211 |
+
return {}, {}, score_all_dict
|
| 212 |
|
| 213 |
def filter_related_paper(self, score_dict, top_k):
|
| 214 |
if len(score_dict) <= top_k:
|