fix bug
Browse files
- configs/datasets.yaml +2 -2
- src/retriever.py +1 -1
- src/utils/paper_client.py +3 -1
- src/utils/paper_retriever.py +5 -13
configs/datasets.yaml
CHANGED
|
@@ -14,10 +14,10 @@ RETRIEVE:
|
|
| 14 |
use_cluster_to_filter: False # 过滤器中使用聚类算法
|
| 15 |
cite_type: "all_cite_id_list"
|
| 16 |
limit_num: 100 # 限制entity对应的paper数量
|
| 17 |
-
sn_num_for_entity:
|
| 18 |
kg_jump_num: 1 # 跳数
|
| 19 |
kg_cover_num: 3 # entity重合数量
|
| 20 |
-
sum_paper_num:
|
| 21 |
sn_retrieve_paper_num: 55 # 通过SN检索到的文章
|
| 22 |
cocite_top_k: 1
|
| 23 |
need_normalize: True
|
|
|
|
| 14 |
use_cluster_to_filter: False # 过滤器中使用聚类算法
|
| 15 |
cite_type: "all_cite_id_list"
|
| 16 |
limit_num: 100 # 限制entity对应的paper数量
|
| 17 |
+
sn_num_for_entity: 3 # SN搜索的文章数量,扩充entity
|
| 18 |
kg_jump_num: 1 # 跳数
|
| 19 |
kg_cover_num: 3 # entity重合数量
|
| 20 |
+
sum_paper_num: 50 # 最多检索到的paper数量
|
| 21 |
sn_retrieve_paper_num: 55 # 通过SN检索到的文章
|
| 22 |
cocite_top_k: 1
|
| 23 |
need_normalize: True
|
src/retriever.py
CHANGED
|
@@ -26,7 +26,7 @@ def main(ctx):
|
|
| 26 |
@click.option(
|
| 27 |
"-c",
|
| 28 |
"--config-path",
|
| 29 |
-
default="
|
| 30 |
type=click.File(),
|
| 31 |
required=True,
|
| 32 |
help="Dataset configuration file in YAML",
|
|
|
|
| 26 |
@click.option(
|
| 27 |
"-c",
|
| 28 |
"--config-path",
|
| 29 |
+
default="./configs/datasets.yaml",
|
| 30 |
type=click.File(),
|
| 31 |
required=True,
|
| 32 |
help="Dataset configuration file in YAML",
|
src/utils/paper_client.py
CHANGED
|
@@ -130,7 +130,6 @@ class PaperClient:
|
|
| 130 |
related_entities.add(entity)
|
| 131 |
|
| 132 |
return list(related_entities)
|
| 133 |
-
|
| 134 |
related_entities = bfs_query(entity_name, n, k)
|
| 135 |
if entity_name in related_entities:
|
| 136 |
related_entities.remove(entity_name)
|
|
@@ -541,6 +540,7 @@ class PaperClient:
|
|
| 541 |
data = {"nodes": [], "relationships": []}
|
| 542 |
query = """
|
| 543 |
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
|
|
|
|
| 544 |
RETURN p, e, r
|
| 545 |
"""
|
| 546 |
results = graph.run(query)
|
|
@@ -572,6 +572,7 @@ class PaperClient:
|
|
| 572 |
WHERE p.venue_name='acl' and p.year='2024'
|
| 573 |
RETURN p
|
| 574 |
"""
|
|
|
|
| 575 |
results = graph.run(query)
|
| 576 |
for record in tqdm(results):
|
| 577 |
paper_node = record["p"]
|
|
@@ -581,6 +582,7 @@ class PaperClient:
|
|
| 581 |
"label": "Paper",
|
| 582 |
"properties": dict(paper_node)
|
| 583 |
})
|
|
|
|
| 584 |
# 去除重复节点
|
| 585 |
# data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
|
| 586 |
unique_nodes = []
|
|
|
|
| 130 |
related_entities.add(entity)
|
| 131 |
|
| 132 |
return list(related_entities)
|
|
|
|
| 133 |
related_entities = bfs_query(entity_name, n, k)
|
| 134 |
if entity_name in related_entities:
|
| 135 |
related_entities.remove(entity_name)
|
|
|
|
| 540 |
data = {"nodes": [], "relationships": []}
|
| 541 |
query = """
|
| 542 |
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
|
| 543 |
+
WHERE p.venue_name='iclr' and p.year='2024'
|
| 544 |
RETURN p, e, r
|
| 545 |
"""
|
| 546 |
results = graph.run(query)
|
|
|
|
| 572 |
WHERE p.venue_name='acl' and p.year='2024'
|
| 573 |
RETURN p
|
| 574 |
"""
|
| 575 |
+
"""
|
| 576 |
results = graph.run(query)
|
| 577 |
for record in tqdm(results):
|
| 578 |
paper_node = record["p"]
|
|
|
|
| 582 |
"label": "Paper",
|
| 583 |
"properties": dict(paper_node)
|
| 584 |
})
|
| 585 |
+
"""
|
| 586 |
# 去除重复节点
|
| 587 |
# data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
|
| 588 |
unique_nodes = []
|
src/utils/paper_retriever.py
CHANGED
|
@@ -124,7 +124,7 @@ class Retriever(object):
|
|
| 124 |
)
|
| 125 |
sum_paper_num = 0
|
| 126 |
for key, value in entity_paper_num_dict.items():
|
| 127 |
-
if sum_paper_num <=
|
| 128 |
sum_paper_num += value
|
| 129 |
new_entities.append(key)
|
| 130 |
elif (
|
|
@@ -188,35 +188,27 @@ class Retriever(object):
|
|
| 188 |
return similarity
|
| 189 |
|
| 190 |
def cal_related_score(
|
| 191 |
-
self, context, related_paper_id_list, entities=None, type_name="
|
| 192 |
):
|
| 193 |
score_1 = np.zeros((len(related_paper_id_list)))
|
| 194 |
score_2 = np.zeros((len(related_paper_id_list)))
|
| 195 |
if entities is None:
|
| 196 |
entities = self.api_helper.generate_entity_list(context)
|
| 197 |
-
logger.debug("get entity from context: {}".format(entities))
|
| 198 |
origin_vector = self.embedding_model.encode(
|
| 199 |
context, convert_to_tensor=True, device=self.device
|
| 200 |
).unsqueeze(0)
|
| 201 |
-
related_contexts = [
|
| 202 |
self.paper_client.get_paper_attribute(paper_id, type_name)
|
| 203 |
for paper_id in related_paper_id_list
|
| 204 |
]
|
| 205 |
-
if len(related_contexts) > 0:
|
| 206 |
-
context_embeddings = self.embedding_model.encode(
|
| 207 |
-
related_contexts,
|
| 208 |
-
batch_size=512,
|
| 209 |
-
convert_to_tensor=True,
|
| 210 |
-
device=self.device,
|
| 211 |
-
)
|
| 212 |
score_1 = torch.nn.functional.cosine_similarity(
|
| 213 |
origin_vector, context_embeddings
|
| 214 |
)
|
| 215 |
score_1 = score_1.cpu().numpy()
|
| 216 |
if self.config.RETRIEVE.need_normalize:
|
| 217 |
score_1 = score_1 / np.max(score_1)
|
| 218 |
-
# score_2 not enable
|
| 219 |
-
# if self.config.RETRIEVE.beta != 0:
|
| 220 |
score_sn_dict = dict(zip(related_paper_id_list, score_1))
|
| 221 |
score_en_dict = dict(zip(related_paper_id_list, score_2))
|
| 222 |
score_all_dict = dict(
|
|
|
|
| 124 |
)
|
| 125 |
sum_paper_num = 0
|
| 126 |
for key, value in entity_paper_num_dict.items():
|
| 127 |
+
if sum_paper_num <= self.config.RETRIEVE.sum_paper_num:
|
| 128 |
sum_paper_num += value
|
| 129 |
new_entities.append(key)
|
| 130 |
elif (
|
|
|
|
| 188 |
return similarity
|
| 189 |
|
| 190 |
def cal_related_score(
|
| 191 |
+
self, context, related_paper_id_list, entities=None, type_name="embedding"
|
| 192 |
):
|
| 193 |
score_1 = np.zeros((len(related_paper_id_list)))
|
| 194 |
score_2 = np.zeros((len(related_paper_id_list)))
|
| 195 |
if entities is None:
|
| 196 |
entities = self.api_helper.generate_entity_list(context)
|
|
|
|
| 197 |
origin_vector = self.embedding_model.encode(
|
| 198 |
context, convert_to_tensor=True, device=self.device
|
| 199 |
).unsqueeze(0)
|
| 200 |
+
context_embeddings = [
|
| 201 |
self.paper_client.get_paper_attribute(paper_id, type_name)
|
| 202 |
for paper_id in related_paper_id_list
|
| 203 |
]
|
| 204 |
+
if len(context_embeddings) > 0:
|
| 205 |
+
context_embeddings = torch.tensor(context_embeddings).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
score_1 = torch.nn.functional.cosine_similarity(
|
| 207 |
origin_vector, context_embeddings
|
| 208 |
)
|
| 209 |
score_1 = score_1.cpu().numpy()
|
| 210 |
if self.config.RETRIEVE.need_normalize:
|
| 211 |
score_1 = score_1 / np.max(score_1)
|
|
|
|
|
|
|
| 212 |
score_sn_dict = dict(zip(related_paper_id_list, score_1))
|
| 213 |
score_en_dict = dict(zip(related_paper_id_list, score_2))
|
| 214 |
score_all_dict = dict(
|