speed up
Browse files- src/utils/paper_retriever.py +13 -13
src/utils/paper_retriever.py
CHANGED
|
@@ -188,15 +188,11 @@ class Retriever(object):
|
|
| 188 |
return similarity
|
| 189 |
|
| 190 |
def cal_related_score(
|
| 191 |
-
self,
|
| 192 |
):
|
| 193 |
score_1 = np.zeros((len(related_paper_id_list)))
|
| 194 |
score_2 = np.zeros((len(related_paper_id_list)))
|
| 195 |
-
|
| 196 |
-
entities = self.api_helper.generate_entity_list(context)
|
| 197 |
-
origin_vector = self.embedding_model.encode(
|
| 198 |
-
context, convert_to_tensor=True, device=self.device
|
| 199 |
-
).unsqueeze(0)
|
| 200 |
context_embeddings = [
|
| 201 |
self.paper_client.get_paper_attribute(paper_id, type_name)
|
| 202 |
for paper_id in related_paper_id_list
|
|
@@ -275,11 +271,10 @@ class Retriever(object):
|
|
| 275 |
break
|
| 276 |
return paper_id_list
|
| 277 |
|
| 278 |
-
def cosine_similarity_search(self,
|
| 279 |
"""
|
| 280 |
return related paper: list
|
| 281 |
"""
|
| 282 |
-
embedding = self.embedding_model.encode(context)
|
| 283 |
result = self.paper_client.cosine_similarity_search(
|
| 284 |
embedding, k, type_name=type_name
|
| 285 |
)
|
|
@@ -506,8 +501,9 @@ class SNRetriever(Retriever):
|
|
| 506 |
|
| 507 |
def retrieve_paper(self, bg):
|
| 508 |
entities = []
|
|
|
|
| 509 |
sn_paper_id_list = self.cosine_similarity_search(
|
| 510 |
-
|
| 511 |
k=self.config.RETRIEVE.sn_retrieve_paper_num,
|
| 512 |
)
|
| 513 |
related_paper = set()
|
|
@@ -524,6 +520,7 @@ class SNRetriever(Retriever):
|
|
| 524 |
related_paper = list(related_paper)
|
| 525 |
logger.debug(f"paper num before filter: {len(related_paper)}")
|
| 526 |
result = {
|
|
|
|
| 527 |
"paper": related_paper,
|
| 528 |
"entities": entities,
|
| 529 |
"cocite_paper": list(cocite_id_set),
|
|
@@ -548,7 +545,7 @@ class SNRetriever(Retriever):
|
|
| 548 |
related_paper_id_list = retrieve_result["paper"]
|
| 549 |
retrieve_paper_num = len(related_paper_id_list)
|
| 550 |
_, _, score_all_dict = self.cal_related_score(
|
| 551 |
-
|
| 552 |
)
|
| 553 |
top_k_matrix = {}
|
| 554 |
recall = 0
|
|
@@ -626,8 +623,9 @@ class KGRetriever(Retriever):
|
|
| 626 |
retrieve_result = self.retrieve_paper(entities)
|
| 627 |
related_paper_id_list = retrieve_result["paper"]
|
| 628 |
retrieve_paper_num = len(related_paper_id_list)
|
|
|
|
| 629 |
_, _, score_all_dict = self.cal_related_score(
|
| 630 |
-
|
| 631 |
)
|
| 632 |
top_k_matrix = {}
|
| 633 |
recall = 0
|
|
@@ -668,8 +666,9 @@ class SNKGRetriever(Retriever):
|
|
| 668 |
|
| 669 |
def retrieve_paper(self, bg, entities):
|
| 670 |
sn_entities = []
|
|
|
|
| 671 |
sn_paper_id_list = self.cosine_similarity_search(
|
| 672 |
-
|
| 673 |
)
|
| 674 |
related_paper = set()
|
| 675 |
related_paper.update(sn_paper_id_list)
|
|
@@ -689,6 +688,7 @@ class SNKGRetriever(Retriever):
|
|
| 689 |
related_paper = related_paper.union(cocite_id_set)
|
| 690 |
related_paper = list(related_paper)
|
| 691 |
result = {
|
|
|
|
| 692 |
"paper": related_paper,
|
| 693 |
"entities": entities,
|
| 694 |
"cocite_paper": list(cocite_id_set),
|
|
@@ -717,7 +717,7 @@ class SNKGRetriever(Retriever):
|
|
| 717 |
retrieve_paper_num = len(related_paper_id_list)
|
| 718 |
logger.info("=== Begin cal related paper score ===")
|
| 719 |
_, _, score_all_dict = self.cal_related_score(
|
| 720 |
-
|
| 721 |
)
|
| 722 |
logger.info("=== End cal related paper score ===")
|
| 723 |
top_k_matrix = {}
|
|
|
|
| 188 |
return similarity
|
| 189 |
|
| 190 |
def cal_related_score(
|
| 191 |
+
self, embedding, related_paper_id_list, type_name="embedding"
|
| 192 |
):
|
| 193 |
score_1 = np.zeros((len(related_paper_id_list)))
|
| 194 |
score_2 = np.zeros((len(related_paper_id_list)))
|
| 195 |
+
origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
context_embeddings = [
|
| 197 |
self.paper_client.get_paper_attribute(paper_id, type_name)
|
| 198 |
for paper_id in related_paper_id_list
|
|
|
|
| 271 |
break
|
| 272 |
return paper_id_list
|
| 273 |
|
| 274 |
+
def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
|
| 275 |
"""
|
| 276 |
return related paper: list
|
| 277 |
"""
|
|
|
|
| 278 |
result = self.paper_client.cosine_similarity_search(
|
| 279 |
embedding, k, type_name=type_name
|
| 280 |
)
|
|
|
|
| 501 |
|
| 502 |
def retrieve_paper(self, bg):
|
| 503 |
entities = []
|
| 504 |
+
embedding = self.embedding_model.encode(bg, device=self.device)
|
| 505 |
sn_paper_id_list = self.cosine_similarity_search(
|
| 506 |
+
embedding=embedding,
|
| 507 |
k=self.config.RETRIEVE.sn_retrieve_paper_num,
|
| 508 |
)
|
| 509 |
related_paper = set()
|
|
|
|
| 520 |
related_paper = list(related_paper)
|
| 521 |
logger.debug(f"paper num before filter: {len(related_paper)}")
|
| 522 |
result = {
|
| 523 |
+
"embedding": embedding,
|
| 524 |
"paper": related_paper,
|
| 525 |
"entities": entities,
|
| 526 |
"cocite_paper": list(cocite_id_set),
|
|
|
|
| 545 |
related_paper_id_list = retrieve_result["paper"]
|
| 546 |
retrieve_paper_num = len(related_paper_id_list)
|
| 547 |
_, _, score_all_dict = self.cal_related_score(
|
| 548 |
+
retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
|
| 549 |
)
|
| 550 |
top_k_matrix = {}
|
| 551 |
recall = 0
|
|
|
|
| 623 |
retrieve_result = self.retrieve_paper(entities)
|
| 624 |
related_paper_id_list = retrieve_result["paper"]
|
| 625 |
retrieve_paper_num = len(related_paper_id_list)
|
| 626 |
+
embedding = self.embedding_model.encode(bg, device=self.device)
|
| 627 |
_, _, score_all_dict = self.cal_related_score(
|
| 628 |
+
embedding, related_paper_id_list=related_paper_id_list
|
| 629 |
)
|
| 630 |
top_k_matrix = {}
|
| 631 |
recall = 0
|
|
|
|
| 666 |
|
| 667 |
def retrieve_paper(self, bg, entities):
|
| 668 |
sn_entities = []
|
| 669 |
+
embedding = self.embedding_model.encode(bg, device=self.device)
|
| 670 |
sn_paper_id_list = self.cosine_similarity_search(
|
| 671 |
+
embedding, k=self.config.RETRIEVE.sn_num_for_entity
|
| 672 |
)
|
| 673 |
related_paper = set()
|
| 674 |
related_paper.update(sn_paper_id_list)
|
|
|
|
| 688 |
related_paper = related_paper.union(cocite_id_set)
|
| 689 |
related_paper = list(related_paper)
|
| 690 |
result = {
|
| 691 |
+
"embedding": embedding,
|
| 692 |
"paper": related_paper,
|
| 693 |
"entities": entities,
|
| 694 |
"cocite_paper": list(cocite_id_set),
|
|
|
|
| 717 |
retrieve_paper_num = len(related_paper_id_list)
|
| 718 |
logger.info("=== Begin cal related paper score ===")
|
| 719 |
_, _, score_all_dict = self.cal_related_score(
|
| 720 |
+
retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
|
| 721 |
)
|
| 722 |
logger.info("=== End cal related paper score ===")
|
| 723 |
top_k_matrix = {}
|