reduce neo4j query time in retrieve
Browse files- app.py +19 -12
- src/app_pages/button_interface.py +21 -10
- src/generator.py +99 -93
- src/paper_manager.py +7 -6
- src/retriever.py +2 -2
- src/utils/api/__init__.py +2 -0
- src/utils/api/base_helper.py +70 -18
- src/utils/api/local_helper.py +39 -0
- src/utils/hash.py +35 -10
- src/utils/llms_api.py +31 -26
- src/utils/paper_client.py +480 -200
- src/utils/paper_retriever.py +12 -21
app.py
CHANGED
|
@@ -1,25 +1,32 @@
|
|
| 1 |
import sys
|
|
|
|
| 2 |
sys.path.append("./src")
|
| 3 |
import streamlit as st
|
| 4 |
-
from app_pages import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from app_pages.locale import _
|
| 6 |
-
from utils.hash import check_env, check_embedding
|
| 7 |
|
| 8 |
if __name__ == "__main__":
|
| 9 |
-
check_env()
|
| 10 |
-
check_embedding()
|
| 11 |
backend = button_interface.Backend()
|
| 12 |
# backend = None
|
| 13 |
st.set_page_config(layout="wide")
|
| 14 |
-
|
| 15 |
-
|
| 16 |
def fn1():
|
| 17 |
one_click_generation.one_click_generation(backend)
|
|
|
|
| 18 |
def fn2():
|
| 19 |
step_by_step_generation.step_by_step_generation(backend)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import sys
|
| 2 |
+
|
| 3 |
sys.path.append("./src")
|
| 4 |
import streamlit as st
|
| 5 |
+
from app_pages import (
|
| 6 |
+
button_interface,
|
| 7 |
+
step_by_step_generation,
|
| 8 |
+
one_click_generation,
|
| 9 |
+
homepage,
|
| 10 |
+
)
|
| 11 |
from app_pages.locale import _
|
|
|
|
| 12 |
|
| 13 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 14 |
backend = button_interface.Backend()
|
| 15 |
# backend = None
|
| 16 |
st.set_page_config(layout="wide")
|
| 17 |
+
|
| 18 |
+
# st.logo("./assets/pic/logo.jpg", size="large")
|
| 19 |
def fn1():
|
| 20 |
one_click_generation.one_click_generation(backend)
|
| 21 |
+
|
| 22 |
def fn2():
|
| 23 |
step_by_step_generation.step_by_step_generation(backend)
|
| 24 |
+
|
| 25 |
+
pg = st.navigation(
|
| 26 |
+
[
|
| 27 |
+
st.Page(homepage.home_page, title=_("🏠️ Homepage")),
|
| 28 |
+
st.Page(fn1, title=_("💧 One-click Generation")),
|
| 29 |
+
st.Page(fn2, title=_("💦 Step-by-step Generation")),
|
| 30 |
+
]
|
| 31 |
+
)
|
| 32 |
+
pg.run()
|
src/app_pages/button_interface.py
CHANGED
|
@@ -2,8 +2,10 @@ import json
|
|
| 2 |
from utils.paper_retriever import RetrieverFactory
|
| 3 |
from utils.llms_api import APIHelper
|
| 4 |
from utils.header import ConfigReader
|
|
|
|
| 5 |
from generator import IdeaGenerator
|
| 6 |
|
|
|
|
| 7 |
class Backend(object):
|
| 8 |
def __init__(self) -> None:
|
| 9 |
CONFIG_PATH = "./configs/datasets.yaml"
|
|
@@ -12,11 +14,14 @@ class Backend(object):
|
|
| 12 |
BRAINSTORM_MODE = "mode_c"
|
| 13 |
|
| 14 |
self.config = ConfigReader.load(CONFIG_PATH)
|
|
|
|
|
|
|
| 15 |
RETRIEVER_NAME = self.config.RETRIEVE.retriever_name
|
| 16 |
self.api_helper = APIHelper(self.config)
|
| 17 |
-
self.retriever_factory =
|
| 18 |
-
|
| 19 |
-
|
|
|
|
| 20 |
)
|
| 21 |
self.idea_generator = IdeaGenerator(self.config, None)
|
| 22 |
self.use_inspiration = USE_INSPIRATION
|
|
@@ -33,14 +38,14 @@ class Backend(object):
|
|
| 33 |
return []
|
| 34 |
|
| 35 |
def background2brainstorm_callback(self, background, json_strs=None):
|
| 36 |
-
if json_strs is not None:
|
| 37 |
json_contents = json.loads(json_strs)
|
| 38 |
return json_contents["brainstorm"]
|
| 39 |
else:
|
| 40 |
return self.api_helper.generate_brainstorm(background)
|
| 41 |
|
| 42 |
def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
|
| 43 |
-
if json_strs is not None:
|
| 44 |
json_contents = json.loads(json_strs)
|
| 45 |
entities_bg = json_contents["entities_bg"]
|
| 46 |
entities_bs = json_contents["entities_bs"]
|
|
@@ -71,13 +76,17 @@ class Backend(object):
|
|
| 71 |
for i, p in enumerate(result["related_paper"]):
|
| 72 |
res.append(str(p))
|
| 73 |
else:
|
| 74 |
-
result = self.retriever_factory.retrieve(
|
|
|
|
|
|
|
| 75 |
res = []
|
| 76 |
for i, p in enumerate(result["related_paper"]):
|
| 77 |
res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
|
| 78 |
return res, result["related_paper"]
|
| 79 |
|
| 80 |
-
def literature2initial_ideas_callback(
|
|
|
|
|
|
|
| 81 |
if json_strs is not None:
|
| 82 |
json_contents = json.loads(json_strs)
|
| 83 |
return json_contents["median"]["initial_idea"]
|
|
@@ -86,15 +95,16 @@ class Backend(object):
|
|
| 86 |
self.idea_generator.brainstorm = brainstorms
|
| 87 |
if self.use_inspiration:
|
| 88 |
message_input, idea_modified, median = (
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
)
|
| 92 |
else:
|
| 93 |
message_input, idea_modified, median = self.idea_generator.generate(
|
| 94 |
background, "new_idea", self.brainstorm_mode, False
|
| 95 |
)
|
| 96 |
return median["initial_idea"], idea_modified
|
| 97 |
-
|
| 98 |
def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
|
| 99 |
if json_strs is not None:
|
| 100 |
json_contents = json.loads(json_strs)
|
|
@@ -107,6 +117,7 @@ class Backend(object):
|
|
| 107 |
return self.examples[i].get("background", "Background not found.")
|
| 108 |
else:
|
| 109 |
return "Example not found. Please select a valid index."
|
|
|
|
| 110 |
# return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
|
| 111 |
# "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
|
| 112 |
# "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
|
|
|
|
| 2 |
from utils.paper_retriever import RetrieverFactory
|
| 3 |
from utils.llms_api import APIHelper
|
| 4 |
from utils.header import ConfigReader
|
| 5 |
+
from utils.hash import check_env, check_embedding
|
| 6 |
from generator import IdeaGenerator
|
| 7 |
|
| 8 |
+
|
| 9 |
class Backend(object):
|
| 10 |
def __init__(self) -> None:
|
| 11 |
CONFIG_PATH = "./configs/datasets.yaml"
|
|
|
|
| 14 |
BRAINSTORM_MODE = "mode_c"
|
| 15 |
|
| 16 |
self.config = ConfigReader.load(CONFIG_PATH)
|
| 17 |
+
check_env()
|
| 18 |
+
check_embedding(self.config.DEFAULT.embedding)
|
| 19 |
RETRIEVER_NAME = self.config.RETRIEVE.retriever_name
|
| 20 |
self.api_helper = APIHelper(self.config)
|
| 21 |
+
self.retriever_factory = (
|
| 22 |
+
RetrieverFactory.get_retriever_factory().create_retriever(
|
| 23 |
+
RETRIEVER_NAME, self.config
|
| 24 |
+
)
|
| 25 |
)
|
| 26 |
self.idea_generator = IdeaGenerator(self.config, None)
|
| 27 |
self.use_inspiration = USE_INSPIRATION
|
|
|
|
| 38 |
return []
|
| 39 |
|
| 40 |
def background2brainstorm_callback(self, background, json_strs=None):
|
| 41 |
+
if json_strs is not None: # only for DEBUG_MODE
|
| 42 |
json_contents = json.loads(json_strs)
|
| 43 |
return json_contents["brainstorm"]
|
| 44 |
else:
|
| 45 |
return self.api_helper.generate_brainstorm(background)
|
| 46 |
|
| 47 |
def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
|
| 48 |
+
if json_strs is not None: # only for DEBUG_MODE
|
| 49 |
json_contents = json.loads(json_strs)
|
| 50 |
entities_bg = json_contents["entities_bg"]
|
| 51 |
entities_bs = json_contents["entities_bs"]
|
|
|
|
| 76 |
for i, p in enumerate(result["related_paper"]):
|
| 77 |
res.append(str(p))
|
| 78 |
else:
|
| 79 |
+
result = self.retriever_factory.retrieve(
|
| 80 |
+
background, entities, need_evaluate=False, target_paper_id_list=[]
|
| 81 |
+
)
|
| 82 |
res = []
|
| 83 |
for i, p in enumerate(result["related_paper"]):
|
| 84 |
res.append(f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.')
|
| 85 |
return res, result["related_paper"]
|
| 86 |
|
| 87 |
+
def literature2initial_ideas_callback(
|
| 88 |
+
self, background, brainstorms, retrieved_literature, json_strs=None
|
| 89 |
+
):
|
| 90 |
if json_strs is not None:
|
| 91 |
json_contents = json.loads(json_strs)
|
| 92 |
return json_contents["median"]["initial_idea"]
|
|
|
|
| 95 |
self.idea_generator.brainstorm = brainstorms
|
| 96 |
if self.use_inspiration:
|
| 97 |
message_input, idea_modified, median = (
|
| 98 |
+
self.idea_generator.generate_by_inspiration(
|
| 99 |
+
background, "new_idea", self.brainstorm_mode, False
|
| 100 |
+
)
|
| 101 |
)
|
| 102 |
else:
|
| 103 |
message_input, idea_modified, median = self.idea_generator.generate(
|
| 104 |
background, "new_idea", self.brainstorm_mode, False
|
| 105 |
)
|
| 106 |
return median["initial_idea"], idea_modified
|
| 107 |
+
|
| 108 |
def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
|
| 109 |
if json_strs is not None:
|
| 110 |
json_contents = json.loads(json_strs)
|
|
|
|
| 117 |
return self.examples[i].get("background", "Background not found.")
|
| 118 |
else:
|
| 119 |
return "Example not found. Please select a valid index."
|
| 120 |
+
|
| 121 |
# return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
|
| 122 |
# "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
|
| 123 |
# "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
|
src/generator.py
CHANGED
|
@@ -10,6 +10,7 @@ import warnings
|
|
| 10 |
import time
|
| 11 |
import os
|
| 12 |
from utils.hash import check_env, check_embedding
|
|
|
|
| 13 |
warnings.filterwarnings("ignore")
|
| 14 |
|
| 15 |
|
|
@@ -24,9 +25,14 @@ def extract_problem(problem, background):
|
|
| 24 |
research_problem = background
|
| 25 |
return research_problem
|
| 26 |
|
|
|
|
| 27 |
class IdeaGenerator:
|
| 28 |
def __init__(
|
| 29 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
) -> None:
|
| 31 |
self.api_helper = APIHelper(config)
|
| 32 |
self.paper_list = paper_list
|
|
@@ -58,7 +64,9 @@ class IdeaGenerator:
|
|
| 58 |
idea = self.api_helper.generate_idea_with_cue_words(
|
| 59 |
problem, self.paper_list, self.cue_words
|
| 60 |
)
|
| 61 |
-
idea_filtered = self.api_helper.integrate_idea(
|
|
|
|
|
|
|
| 62 |
return message_input, problem, idea, idea_filtered
|
| 63 |
|
| 64 |
def generate_without_cue_words_bs(self, background: str):
|
|
@@ -66,7 +74,9 @@ class IdeaGenerator:
|
|
| 66 |
background, self.paper_list
|
| 67 |
)
|
| 68 |
idea = self.api_helper.generate_idea(problem, self.paper_list)
|
| 69 |
-
idea_filtered = self.api_helper.integrate_idea(
|
|
|
|
|
|
|
| 70 |
return message_input, problem, idea, idea_filtered
|
| 71 |
|
| 72 |
def generate_with_cue_words_ins(self, background: str):
|
|
@@ -93,16 +103,12 @@ class IdeaGenerator:
|
|
| 93 |
research_problem = extract_problem(problem, background)
|
| 94 |
inspirations = []
|
| 95 |
for paper in self.paper_list:
|
| 96 |
-
inspiration = self.api_helper.generate_inspiration(
|
| 97 |
-
research_problem, paper
|
| 98 |
-
)
|
| 99 |
inspirations.append(inspiration)
|
| 100 |
-
idea = self.api_helper.generate_idea_by_inspiration(
|
| 101 |
-
problem, inspirations
|
| 102 |
-
)
|
| 103 |
idea_filtered = self.api_helper.filter_idea(idea, background)
|
| 104 |
return message_input, problem, inspirations, idea, idea_filtered
|
| 105 |
-
|
| 106 |
def generate_with_cue_words_ins_bs(self, background: str):
|
| 107 |
problem, message_input = self.api_helper.generate_problem_with_cue_words(
|
| 108 |
background, self.paper_list, self.cue_words
|
|
@@ -117,7 +123,9 @@ class IdeaGenerator:
|
|
| 117 |
idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
|
| 118 |
problem, inspirations, self.cue_words
|
| 119 |
)
|
| 120 |
-
idea_filtered = self.api_helper.integrate_idea(
|
|
|
|
|
|
|
| 121 |
return message_input, problem, inspirations, idea, idea_filtered
|
| 122 |
|
| 123 |
def generate_without_cue_words_ins_bs(self, background: str):
|
|
@@ -127,14 +135,12 @@ class IdeaGenerator:
|
|
| 127 |
research_problem = extract_problem(problem, background)
|
| 128 |
inspirations = []
|
| 129 |
for paper in self.paper_list:
|
| 130 |
-
inspiration = self.api_helper.generate_inspiration(
|
| 131 |
-
research_problem, paper
|
| 132 |
-
)
|
| 133 |
inspirations.append(inspiration)
|
| 134 |
-
idea = self.api_helper.generate_idea_by_inspiration(
|
| 135 |
-
|
|
|
|
| 136 |
)
|
| 137 |
-
idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
|
| 138 |
return message_input, problem, inspirations, idea, idea_filtered
|
| 139 |
|
| 140 |
def generate(
|
|
@@ -151,44 +157,34 @@ class IdeaGenerator:
|
|
| 151 |
mode_name = "Generate new idea"
|
| 152 |
if bs_mode == "mode_a":
|
| 153 |
if use_cue_words:
|
| 154 |
-
logger.info(
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
idea,
|
| 159 |
-
idea_filtered
|
| 160 |
-
) = (
|
| 161 |
self.generate_with_cue_words(background)
|
| 162 |
)
|
| 163 |
else:
|
| 164 |
-
logger.info(
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
idea,
|
| 169 |
-
idea_filtered
|
| 170 |
-
) = (
|
| 171 |
self.generate_without_cue_words(background)
|
| 172 |
)
|
| 173 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
| 174 |
if use_cue_words:
|
| 175 |
-
logger.info(
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
idea,
|
| 180 |
-
idea_filtered
|
| 181 |
-
) = (
|
| 182 |
self.generate_with_cue_words_bs(background)
|
| 183 |
)
|
| 184 |
else:
|
| 185 |
-
logger.info(
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
) = (
|
| 192 |
self.generate_without_cue_words_bs(background)
|
| 193 |
)
|
| 194 |
|
|
@@ -214,48 +210,34 @@ class IdeaGenerator:
|
|
| 214 |
mode_name = "Generate new idea"
|
| 215 |
if bs_mode == "mode_a":
|
| 216 |
if use_cue_words:
|
| 217 |
-
logger.info(
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
inspirations,
|
| 222 |
-
idea,
|
| 223 |
-
idea_filtered
|
| 224 |
-
) = (
|
| 225 |
self.generate_with_cue_words_ins(background)
|
| 226 |
)
|
| 227 |
else:
|
| 228 |
-
logger.info(
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
inspirations,
|
| 233 |
-
idea,
|
| 234 |
-
idea_filtered
|
| 235 |
-
) = (
|
| 236 |
self.generate_without_cue_words_ins(background)
|
| 237 |
)
|
| 238 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
| 239 |
if use_cue_words:
|
| 240 |
-
logger.info(
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
inspirations,
|
| 245 |
-
idea,
|
| 246 |
-
idea_filtered
|
| 247 |
-
) = (
|
| 248 |
self.generate_with_cue_words_ins_bs(background)
|
| 249 |
)
|
| 250 |
else:
|
| 251 |
-
logger.info(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
idea_filtered
|
| 258 |
-
) = (
|
| 259 |
self.generate_without_cue_words_ins_bs(background)
|
| 260 |
)
|
| 261 |
|
|
@@ -330,9 +312,18 @@ def main(ctx):
|
|
| 330 |
required=False,
|
| 331 |
help="The number of papers you want to process",
|
| 332 |
)
|
| 333 |
-
def backtracking(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
check_env()
|
| 335 |
-
check_embedding()
|
| 336 |
# Configuration
|
| 337 |
config = ConfigReader.load(config_path, **kwargs)
|
| 338 |
logger.add(
|
|
@@ -349,7 +340,10 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
| 349 |
batch_size = 2
|
| 350 |
output_dir = "./assets/output_idea/"
|
| 351 |
os.makedirs(output_dir, exist_ok=True)
|
| 352 |
-
output_file = os.path.join(
|
|
|
|
|
|
|
|
|
|
| 353 |
if os.path.exists(output_file):
|
| 354 |
with open(output_file, "r", encoding="utf-8") as f:
|
| 355 |
try:
|
|
@@ -388,7 +382,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
| 388 |
if brainstorm_mode == "mode_c":
|
| 389 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
| 390 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
| 391 |
-
entities_all = list(set(entities)|set(entities_bs))
|
| 392 |
else:
|
| 393 |
entities_bs = None
|
| 394 |
entities_all = entities
|
|
@@ -404,8 +398,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
| 404 |
continue
|
| 405 |
# 3. 检索相关论文
|
| 406 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
| 407 |
-
retriever_name,
|
| 408 |
-
config
|
| 409 |
)
|
| 410 |
result = rt.retrieve(
|
| 411 |
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
|
|
@@ -438,7 +431,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
| 438 |
"hash_id": paper["hash_id"],
|
| 439 |
"background": bg,
|
| 440 |
"entities_bg": entities,
|
| 441 |
-
"brainstorm"
|
| 442 |
"entities_bs": entities_bs,
|
| 443 |
"entities_rt": entities_rt,
|
| 444 |
"related_paper": [p["hash_id"] for p in related_paper],
|
|
@@ -467,6 +460,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
| 467 |
) as f:
|
| 468 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
| 469 |
|
|
|
|
| 470 |
@main.command()
|
| 471 |
@click.option(
|
| 472 |
"-c",
|
|
@@ -512,9 +506,16 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
|
|
| 512 |
required=False,
|
| 513 |
help="The number of data you want to process",
|
| 514 |
)
|
| 515 |
-
def new_idea(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
check_env()
|
| 517 |
-
check_embedding()
|
| 518 |
logger.add(
|
| 519 |
"log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
|
| 520 |
) # 添加文件输出
|
|
@@ -522,6 +523,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
| 522 |
# Configuration
|
| 523 |
config = ConfigReader.load(config_path, **kwargs)
|
| 524 |
api_helper = APIHelper(config)
|
|
|
|
| 525 |
eval_data = []
|
| 526 |
cur_num = 0
|
| 527 |
data_num = 0
|
|
@@ -529,7 +531,9 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
| 529 |
bg_ids = set()
|
| 530 |
output_dir = "./assets/output_idea/"
|
| 531 |
os.makedirs(output_dir, exist_ok=True)
|
| 532 |
-
output_file = os.path.join(
|
|
|
|
|
|
|
| 533 |
if os.path.exists(output_file):
|
| 534 |
with open(output_file, "r", encoding="utf-8") as f:
|
| 535 |
try:
|
|
@@ -538,7 +542,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
| 538 |
cur_num = len(eval_data)
|
| 539 |
except json.JSONDecodeError:
|
| 540 |
eval_data = []
|
| 541 |
-
|
| 542 |
for line in ids_path:
|
| 543 |
# 解析每行的JSON数据
|
| 544 |
data = json.loads(line)
|
|
@@ -568,16 +572,17 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
| 568 |
if brainstorm_mode == "mode_c":
|
| 569 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
| 570 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
| 571 |
-
entities_all = list(set(entities)|set(entities_bs))
|
| 572 |
else:
|
| 573 |
entities_bs = None
|
| 574 |
entities_all = entities
|
| 575 |
# 2. 检索相关论文
|
| 576 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
| 577 |
-
retriever_name,
|
| 578 |
-
|
|
|
|
|
|
|
| 579 |
)
|
| 580 |
-
result = rt.retrieve(bg, entities_all, need_evaluate=False, target_paper_id_list=[])
|
| 581 |
related_paper = result["related_paper"]
|
| 582 |
logger.info("Find {} related papers...".format(len(related_paper)))
|
| 583 |
entities_rt = result["entities"]
|
|
@@ -597,7 +602,7 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
| 597 |
{
|
| 598 |
"background": bg,
|
| 599 |
"entities_bg": entities,
|
| 600 |
-
"brainstorm"
|
| 601 |
"entities_bs": entities_bs,
|
| 602 |
"entities_rt": entities_rt,
|
| 603 |
"related_paper": [p["hash_id"] for p in related_paper],
|
|
@@ -621,5 +626,6 @@ def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspira
|
|
| 621 |
with open(output_file, "w", encoding="utf-8") as f:
|
| 622 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
| 623 |
|
|
|
|
| 624 |
if __name__ == "__main__":
|
| 625 |
main()
|
|
|
|
| 10 |
import time
|
| 11 |
import os
|
| 12 |
from utils.hash import check_env, check_embedding
|
| 13 |
+
|
| 14 |
warnings.filterwarnings("ignore")
|
| 15 |
|
| 16 |
|
|
|
|
| 25 |
research_problem = background
|
| 26 |
return research_problem
|
| 27 |
|
| 28 |
+
|
| 29 |
class IdeaGenerator:
|
| 30 |
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
config,
|
| 33 |
+
paper_list: list[dict] = [],
|
| 34 |
+
cue_words: list = None,
|
| 35 |
+
brainstorm: str = None,
|
| 36 |
) -> None:
|
| 37 |
self.api_helper = APIHelper(config)
|
| 38 |
self.paper_list = paper_list
|
|
|
|
| 64 |
idea = self.api_helper.generate_idea_with_cue_words(
|
| 65 |
problem, self.paper_list, self.cue_words
|
| 66 |
)
|
| 67 |
+
idea_filtered = self.api_helper.integrate_idea(
|
| 68 |
+
background, self.brainstorm, idea
|
| 69 |
+
)
|
| 70 |
return message_input, problem, idea, idea_filtered
|
| 71 |
|
| 72 |
def generate_without_cue_words_bs(self, background: str):
|
|
|
|
| 74 |
background, self.paper_list
|
| 75 |
)
|
| 76 |
idea = self.api_helper.generate_idea(problem, self.paper_list)
|
| 77 |
+
idea_filtered = self.api_helper.integrate_idea(
|
| 78 |
+
background, self.brainstorm, idea
|
| 79 |
+
)
|
| 80 |
return message_input, problem, idea, idea_filtered
|
| 81 |
|
| 82 |
def generate_with_cue_words_ins(self, background: str):
|
|
|
|
| 103 |
research_problem = extract_problem(problem, background)
|
| 104 |
inspirations = []
|
| 105 |
for paper in self.paper_list:
|
| 106 |
+
inspiration = self.api_helper.generate_inspiration(research_problem, paper)
|
|
|
|
|
|
|
| 107 |
inspirations.append(inspiration)
|
| 108 |
+
idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
|
|
|
|
|
|
|
| 109 |
idea_filtered = self.api_helper.filter_idea(idea, background)
|
| 110 |
return message_input, problem, inspirations, idea, idea_filtered
|
| 111 |
+
|
| 112 |
def generate_with_cue_words_ins_bs(self, background: str):
|
| 113 |
problem, message_input = self.api_helper.generate_problem_with_cue_words(
|
| 114 |
background, self.paper_list, self.cue_words
|
|
|
|
| 123 |
idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
|
| 124 |
problem, inspirations, self.cue_words
|
| 125 |
)
|
| 126 |
+
idea_filtered = self.api_helper.integrate_idea(
|
| 127 |
+
background, self.brainstorm, idea
|
| 128 |
+
)
|
| 129 |
return message_input, problem, inspirations, idea, idea_filtered
|
| 130 |
|
| 131 |
def generate_without_cue_words_ins_bs(self, background: str):
|
|
|
|
| 135 |
research_problem = extract_problem(problem, background)
|
| 136 |
inspirations = []
|
| 137 |
for paper in self.paper_list:
|
| 138 |
+
inspiration = self.api_helper.generate_inspiration(research_problem, paper)
|
|
|
|
|
|
|
| 139 |
inspirations.append(inspiration)
|
| 140 |
+
idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
|
| 141 |
+
idea_filtered = self.api_helper.integrate_idea(
|
| 142 |
+
background, self.brainstorm, idea
|
| 143 |
)
|
|
|
|
| 144 |
return message_input, problem, inspirations, idea, idea_filtered
|
| 145 |
|
| 146 |
def generate(
|
|
|
|
| 157 |
mode_name = "Generate new idea"
|
| 158 |
if bs_mode == "mode_a":
|
| 159 |
if use_cue_words:
|
| 160 |
+
logger.info(
|
| 161 |
+
"{} using brainstorm_mode_a with cue words.".format(mode_name)
|
| 162 |
+
)
|
| 163 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
|
| 164 |
self.generate_with_cue_words(background)
|
| 165 |
)
|
| 166 |
else:
|
| 167 |
+
logger.info(
|
| 168 |
+
"{} using brainstorm_mode_a without cue words.".format(mode_name)
|
| 169 |
+
)
|
| 170 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
|
| 171 |
self.generate_without_cue_words(background)
|
| 172 |
)
|
| 173 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
| 174 |
if use_cue_words:
|
| 175 |
+
logger.info(
|
| 176 |
+
"{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
|
| 177 |
+
)
|
| 178 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
|
| 179 |
self.generate_with_cue_words_bs(background)
|
| 180 |
)
|
| 181 |
else:
|
| 182 |
+
logger.info(
|
| 183 |
+
"{} using brainstorm_{} without cue words.".format(
|
| 184 |
+
mode_name, bs_mode
|
| 185 |
+
)
|
| 186 |
+
)
|
| 187 |
+
(message_input, problem, idea, idea_filtered) = (
|
|
|
|
| 188 |
self.generate_without_cue_words_bs(background)
|
| 189 |
)
|
| 190 |
|
|
|
|
| 210 |
mode_name = "Generate new idea"
|
| 211 |
if bs_mode == "mode_a":
|
| 212 |
if use_cue_words:
|
| 213 |
+
logger.info(
|
| 214 |
+
"{} using brainstorm_mode_a with cue words.".format(mode_name)
|
| 215 |
+
)
|
| 216 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
self.generate_with_cue_words_ins(background)
|
| 218 |
)
|
| 219 |
else:
|
| 220 |
+
logger.info(
|
| 221 |
+
"{} using brainstorm_mode_a without cue words.".format(mode_name)
|
| 222 |
+
)
|
| 223 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
self.generate_without_cue_words_ins(background)
|
| 225 |
)
|
| 226 |
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
| 227 |
if use_cue_words:
|
| 228 |
+
logger.info(
|
| 229 |
+
"{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
|
| 230 |
+
)
|
| 231 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
self.generate_with_cue_words_ins_bs(background)
|
| 233 |
)
|
| 234 |
else:
|
| 235 |
+
logger.info(
|
| 236 |
+
"{} using brainstorm_{} without cue words.".format(
|
| 237 |
+
mode_name, bs_mode
|
| 238 |
+
)
|
| 239 |
+
)
|
| 240 |
+
(message_input, problem, inspirations, idea, idea_filtered) = (
|
|
|
|
|
|
|
| 241 |
self.generate_without_cue_words_ins_bs(background)
|
| 242 |
)
|
| 243 |
|
|
|
|
| 312 |
required=False,
|
| 313 |
help="The number of papers you want to process",
|
| 314 |
)
|
| 315 |
+
def backtracking(
|
| 316 |
+
config_path,
|
| 317 |
+
ids_path,
|
| 318 |
+
retriever_name,
|
| 319 |
+
brainstorm_mode,
|
| 320 |
+
use_cue_words,
|
| 321 |
+
use_inspiration,
|
| 322 |
+
num,
|
| 323 |
+
**kwargs,
|
| 324 |
+
):
|
| 325 |
check_env()
|
| 326 |
+
check_embedding()
|
| 327 |
# Configuration
|
| 328 |
config = ConfigReader.load(config_path, **kwargs)
|
| 329 |
logger.add(
|
|
|
|
| 340 |
batch_size = 2
|
| 341 |
output_dir = "./assets/output_idea/"
|
| 342 |
os.makedirs(output_dir, exist_ok=True)
|
| 343 |
+
output_file = os.path.join(
|
| 344 |
+
output_dir,
|
| 345 |
+
f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
|
| 346 |
+
)
|
| 347 |
if os.path.exists(output_file):
|
| 348 |
with open(output_file, "r", encoding="utf-8") as f:
|
| 349 |
try:
|
|
|
|
| 382 |
if brainstorm_mode == "mode_c":
|
| 383 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
| 384 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
| 385 |
+
entities_all = list(set(entities) | set(entities_bs))
|
| 386 |
else:
|
| 387 |
entities_bs = None
|
| 388 |
entities_all = entities
|
|
|
|
| 398 |
continue
|
| 399 |
# 3. 检索相关论文
|
| 400 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
| 401 |
+
retriever_name, config
|
|
|
|
| 402 |
)
|
| 403 |
result = rt.retrieve(
|
| 404 |
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
|
|
|
|
| 431 |
"hash_id": paper["hash_id"],
|
| 432 |
"background": bg,
|
| 433 |
"entities_bg": entities,
|
| 434 |
+
"brainstorm": brainstorm,
|
| 435 |
"entities_bs": entities_bs,
|
| 436 |
"entities_rt": entities_rt,
|
| 437 |
"related_paper": [p["hash_id"] for p in related_paper],
|
|
|
|
| 460 |
) as f:
|
| 461 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
| 462 |
|
| 463 |
+
|
| 464 |
@main.command()
|
| 465 |
@click.option(
|
| 466 |
"-c",
|
|
|
|
| 506 |
required=False,
|
| 507 |
help="The number of data you want to process",
|
| 508 |
)
|
| 509 |
+
def new_idea(
|
| 510 |
+
config_path,
|
| 511 |
+
ids_path,
|
| 512 |
+
retriever_name,
|
| 513 |
+
brainstorm_mode,
|
| 514 |
+
use_inspiration,
|
| 515 |
+
num,
|
| 516 |
+
**kwargs,
|
| 517 |
+
):
|
| 518 |
check_env()
|
|
|
|
| 519 |
logger.add(
|
| 520 |
"log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
|
| 521 |
) # 添加文件输出
|
|
|
|
| 523 |
# Configuration
|
| 524 |
config = ConfigReader.load(config_path, **kwargs)
|
| 525 |
api_helper = APIHelper(config)
|
| 526 |
+
check_embedding(config.DEFAULT.embedding)
|
| 527 |
eval_data = []
|
| 528 |
cur_num = 0
|
| 529 |
data_num = 0
|
|
|
|
| 531 |
bg_ids = set()
|
| 532 |
output_dir = "./assets/output_idea/"
|
| 533 |
os.makedirs(output_dir, exist_ok=True)
|
| 534 |
+
output_file = os.path.join(
|
| 535 |
+
output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
|
| 536 |
+
)
|
| 537 |
if os.path.exists(output_file):
|
| 538 |
with open(output_file, "r", encoding="utf-8") as f:
|
| 539 |
try:
|
|
|
|
| 542 |
cur_num = len(eval_data)
|
| 543 |
except json.JSONDecodeError:
|
| 544 |
eval_data = []
|
| 545 |
+
logger.debug(f"{cur_num} datas have been processed.")
|
| 546 |
for line in ids_path:
|
| 547 |
# 解析每行的JSON数据
|
| 548 |
data = json.loads(line)
|
|
|
|
| 572 |
if brainstorm_mode == "mode_c":
|
| 573 |
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
|
| 574 |
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
|
| 575 |
+
entities_all = list(set(entities) | set(entities_bs))
|
| 576 |
else:
|
| 577 |
entities_bs = None
|
| 578 |
entities_all = entities
|
| 579 |
# 2. 检索相关论文
|
| 580 |
rt = RetrieverFactory.get_retriever_factory().create_retriever(
|
| 581 |
+
retriever_name, config
|
| 582 |
+
)
|
| 583 |
+
result = rt.retrieve(
|
| 584 |
+
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
|
| 585 |
)
|
|
|
|
| 586 |
related_paper = result["related_paper"]
|
| 587 |
logger.info("Find {} related papers...".format(len(related_paper)))
|
| 588 |
entities_rt = result["entities"]
|
|
|
|
| 602 |
{
|
| 603 |
"background": bg,
|
| 604 |
"entities_bg": entities,
|
| 605 |
+
"brainstorm": brainstorm,
|
| 606 |
"entities_bs": entities_bs,
|
| 607 |
"entities_rt": entities_rt,
|
| 608 |
"related_paper": [p["hash_id"] for p in related_paper],
|
|
|
|
| 626 |
with open(output_file, "w", encoding="utf-8") as f:
|
| 627 |
json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
| 628 |
|
| 629 |
+
|
| 630 |
if __name__ == "__main__":
|
| 631 |
main()
|
src/paper_manager.py
CHANGED
|
@@ -389,10 +389,8 @@ class PaperManager:
|
|
| 389 |
)
|
| 390 |
|
| 391 |
if need_summary:
|
| 392 |
-
print(paper.keys())
|
| 393 |
if not self.check_parse(paper):
|
| 394 |
logger.error(f"paper {paper['hash_id']} need parse first...")
|
| 395 |
-
|
| 396 |
result = self.api_helper(
|
| 397 |
paper["title"], paper["abstract"], paper["introduction"]
|
| 398 |
)
|
|
@@ -628,9 +626,11 @@ class PaperManager:
|
|
| 628 |
|
| 629 |
def insert_embedding(self, hash_id=None):
|
| 630 |
self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
|
| 631 |
-
# self.
|
| 632 |
-
# self.
|
| 633 |
-
#
|
|
|
|
|
|
|
| 634 |
|
| 635 |
def cosine_similarity_search(self, data_type, context, k=1):
|
| 636 |
"""
|
|
@@ -837,8 +837,9 @@ def local(config_path, year, venue_name, output, **kwargs):
|
|
| 837 |
os.makedirs(os.path.dirname(output_path))
|
| 838 |
config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
|
| 839 |
pm = PaperManager(config, venue_name, year)
|
|
|
|
| 840 |
pm.update_paper_from_json_to_json(
|
| 841 |
-
need_download=True, need_parse=True, need_summary=True
|
| 842 |
)
|
| 843 |
|
| 844 |
|
|
|
|
| 389 |
)
|
| 390 |
|
| 391 |
if need_summary:
|
|
|
|
| 392 |
if not self.check_parse(paper):
|
| 393 |
logger.error(f"paper {paper['hash_id']} need parse first...")
|
|
|
|
| 394 |
result = self.api_helper(
|
| 395 |
paper["title"], paper["abstract"], paper["introduction"]
|
| 396 |
)
|
|
|
|
| 626 |
|
| 627 |
def insert_embedding(self, hash_id=None):
|
| 628 |
self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
|
| 629 |
+
# self.paper_client.add_paper_bg_embedding(self.embedding_model, hash_id)
|
| 630 |
+
# self.paper_client.add_paper_contribution_embedding(
|
| 631 |
+
# self.embedding_model, hash_id
|
| 632 |
+
# )
|
| 633 |
+
# self.paper_client.add_paper_summary_embedding(self.embedding_model, hash_id)
|
| 634 |
|
| 635 |
def cosine_similarity_search(self, data_type, context, k=1):
|
| 636 |
"""
|
|
|
|
| 837 |
os.makedirs(os.path.dirname(output_path))
|
| 838 |
config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
|
| 839 |
pm = PaperManager(config, venue_name, year)
|
| 840 |
+
print("###")
|
| 841 |
pm.update_paper_from_json_to_json(
|
| 842 |
+
need_download=True, need_parse=True, need_summary=True
|
| 843 |
)
|
| 844 |
|
| 845 |
|
src/retriever.py
CHANGED
|
@@ -41,9 +41,9 @@ def main(ctx):
|
|
| 41 |
def retrieve(
|
| 42 |
config_path, ids_path, **kwargs
|
| 43 |
):
|
| 44 |
-
check_env()
|
| 45 |
-
check_embedding()
|
| 46 |
config = ConfigReader.load(config_path, **kwargs)
|
|
|
|
|
|
|
| 47 |
log_dir = config.DEFAULT.log_dir
|
| 48 |
retriever_name = config.RETRIEVE.retriever_name
|
| 49 |
cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
|
|
|
|
| 41 |
def retrieve(
|
| 42 |
config_path, ids_path, **kwargs
|
| 43 |
):
|
|
|
|
|
|
|
| 44 |
config = ConfigReader.load(config_path, **kwargs)
|
| 45 |
+
check_embedding(config.DEFAULT.embedding)
|
| 46 |
+
check_env()
|
| 47 |
log_dir = config.DEFAULT.log_dir
|
| 48 |
retriever_name = config.RETRIEVE.retriever_name
|
| 49 |
cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
|
src/utils/api/__init__.py
CHANGED
|
@@ -22,8 +22,10 @@ Creation Date : 2024-10-29
|
|
| 22 |
|
| 23 |
Author : Frank Kang(frankkang@zju.edu.cn)
|
| 24 |
"""
|
|
|
|
| 25 |
from .base_helper import HelperCompany
|
| 26 |
from .openai_helper import OpenAIHelper # noqa: F401, ensure autoregister
|
| 27 |
from .zhipuai_helper import ZhipuAIHelper # noqa: F401, ensure autoregister
|
|
|
|
| 28 |
|
| 29 |
__all__ = ["HelperCompany"]
|
|
|
|
| 22 |
|
| 23 |
Author : Frank Kang(frankkang@zju.edu.cn)
|
| 24 |
"""
|
| 25 |
+
|
| 26 |
from .base_helper import HelperCompany
|
| 27 |
from .openai_helper import OpenAIHelper # noqa: F401, ensure autoregister
|
| 28 |
from .zhipuai_helper import ZhipuAIHelper # noqa: F401, ensure autoregister
|
| 29 |
+
from .local_helper import LocalHelper # noqa: F401, ensure autoregister
|
| 30 |
|
| 31 |
__all__ = ["HelperCompany"]
|
src/utils/api/base_helper.py
CHANGED
|
@@ -17,6 +17,9 @@ from abc import ABCMeta
|
|
| 17 |
from typing_extensions import Literal, override
|
| 18 |
from ..base_company import BaseCompany
|
| 19 |
from typing import Union
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class NotGiven:
|
|
@@ -109,6 +112,31 @@ class BaseHelper:
|
|
| 109 |
self.base_url = base_url
|
| 110 |
self.client = None
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def create(
|
| 113 |
self,
|
| 114 |
*args,
|
|
@@ -124,7 +152,7 @@ class BaseHelper:
|
|
| 124 |
extra_headers: None | NotGiven = None,
|
| 125 |
extra_body: None | NotGiven = None,
|
| 126 |
timeout: float | None | NotGiven = None,
|
| 127 |
-
**kwargs
|
| 128 |
):
|
| 129 |
"""
|
| 130 |
Creates a model response for the given chat conversation.
|
|
@@ -187,20 +215,44 @@ class BaseHelper:
|
|
| 187 |
|
| 188 |
timeout: Override the client-level default timeout for this request, in seconds
|
| 189 |
"""
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from typing_extensions import Literal, override
|
| 18 |
from ..base_company import BaseCompany
|
| 19 |
from typing import Union
|
| 20 |
+
import requests
|
| 21 |
+
import json
|
| 22 |
+
from requests.exceptions import RequestException
|
| 23 |
|
| 24 |
|
| 25 |
class NotGiven:
|
|
|
|
| 112 |
self.base_url = base_url
|
| 113 |
self.client = None
|
| 114 |
|
| 115 |
+
def apply_for_service(self, data_param, max_attempts=4):
|
| 116 |
+
attempt = 0
|
| 117 |
+
while attempt < max_attempts:
|
| 118 |
+
try:
|
| 119 |
+
# print(f"尝试 #{attempt + 1}")
|
| 120 |
+
r = requests.post(
|
| 121 |
+
self.base_url + "/llm/generate",
|
| 122 |
+
headers={"Content-Type": "application/json"},
|
| 123 |
+
data=json.dumps(data_param),
|
| 124 |
+
)
|
| 125 |
+
# 检查请求是否成功
|
| 126 |
+
if r.status_code == 200:
|
| 127 |
+
# print("服务请求成功。")
|
| 128 |
+
response = r.json()["data"]["output"]
|
| 129 |
+
return response # 或者根据需要返回其他内容
|
| 130 |
+
else:
|
| 131 |
+
print("服务请求失败,响应状态码:", response.status_code)
|
| 132 |
+
except RequestException as e:
|
| 133 |
+
print("请求发生错误:", e)
|
| 134 |
+
|
| 135 |
+
attempt += 1
|
| 136 |
+
if attempt == max_attempts:
|
| 137 |
+
print("达到最大尝试次数,服务请求失败。")
|
| 138 |
+
return None # 或者根据你的情况抛出异常
|
| 139 |
+
|
| 140 |
def create(
|
| 141 |
self,
|
| 142 |
*args,
|
|
|
|
| 152 |
extra_headers: None | NotGiven = None,
|
| 153 |
extra_body: None | NotGiven = None,
|
| 154 |
timeout: float | None | NotGiven = None,
|
| 155 |
+
**kwargs,
|
| 156 |
):
|
| 157 |
"""
|
| 158 |
Creates a model response for the given chat conversation.
|
|
|
|
| 215 |
|
| 216 |
timeout: Override the client-level default timeout for this request, in seconds
|
| 217 |
"""
|
| 218 |
+
if self.model != "local":
|
| 219 |
+
return (
|
| 220 |
+
self.client.chat.completions.create(
|
| 221 |
+
*args,
|
| 222 |
+
model=self.model,
|
| 223 |
+
messages=messages,
|
| 224 |
+
stream=stream,
|
| 225 |
+
temperature=temperature,
|
| 226 |
+
top_p=top_p,
|
| 227 |
+
max_tokens=max_tokens,
|
| 228 |
+
seed=seed,
|
| 229 |
+
stop=stop,
|
| 230 |
+
tools=tools,
|
| 231 |
+
tool_choice=tool_choice,
|
| 232 |
+
extra_headers=extra_headers,
|
| 233 |
+
extra_body=extra_body,
|
| 234 |
+
timeout=timeout,
|
| 235 |
+
**kwargs,
|
| 236 |
+
)
|
| 237 |
+
.choices[0]
|
| 238 |
+
.message.content
|
| 239 |
+
)
|
| 240 |
+
else:
|
| 241 |
+
default_system = "You are a helpful assistant."
|
| 242 |
+
input_content = ""
|
| 243 |
+
for message in messages:
|
| 244 |
+
if message["role"] == "system":
|
| 245 |
+
default_system = message["content"]
|
| 246 |
+
else:
|
| 247 |
+
input_content += message["content"]
|
| 248 |
+
data_param = {}
|
| 249 |
+
data_param["input"] = input_content
|
| 250 |
+
data_param["serviceParams"] = {"stream": False, "system": default_system}
|
| 251 |
+
data_param["ModelParams"] = {
|
| 252 |
+
"temperature": 0.8,
|
| 253 |
+
"presence_penalty": 2.0,
|
| 254 |
+
"frequency_penalty": 0.0,
|
| 255 |
+
"top_p": 0.8,
|
| 256 |
+
}
|
| 257 |
+
response = self.apply_for_service(data_param)
|
| 258 |
+
return response
|
src/utils/api/local_helper.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r"""_summary_
|
| 2 |
+
-*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
Module : data.utils.api.zhipuai_helper
|
| 5 |
+
|
| 6 |
+
File Name : zhipuai_helper.py
|
| 7 |
+
|
| 8 |
+
Description : Helper class for ZhipuAI interface, generally not used directly.
|
| 9 |
+
For example:
|
| 10 |
+
```
|
| 11 |
+
from data.utils.api import HelperCompany
|
| 12 |
+
helper = HelperCompany.get()['ZhipuAI']
|
| 13 |
+
...
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Creation Date : 2024-11-28
|
| 17 |
+
|
| 18 |
+
Author : lihuigu(lihuigu@zju.edu.cn)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from .base_helper import register_helper, BaseHelper
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@register_helper("Local")
|
| 25 |
+
class LocalHelper(BaseHelper):
|
| 26 |
+
"""_summary_
|
| 27 |
+
|
| 28 |
+
Helper class for ZhipuAI interface, generally not used directly.
|
| 29 |
+
|
| 30 |
+
For example:
|
| 31 |
+
```
|
| 32 |
+
from data.utils.api import HelperCompany
|
| 33 |
+
helper = HelperCompany.get()['Local']
|
| 34 |
+
...
|
| 35 |
+
```
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, api_key, model, base_url=None, timeout=None):
|
| 39 |
+
super().__init__(api_key, model, base_url)
|
src/utils/hash.py
CHANGED
|
@@ -12,18 +12,35 @@ ENV_CHECKED = False
|
|
| 12 |
EMBEDDING_CHECKED = False
|
| 13 |
|
| 14 |
|
| 15 |
-
def check_embedding():
|
|
|
|
| 16 |
global EMBEDDING_CHECKED
|
| 17 |
if not EMBEDDING_CHECKED:
|
| 18 |
# Define the repository and files to download
|
| 19 |
-
repo_id = "sentence-transformers/all-MiniLM-L6-v2" # "BAAI/bge-small-en-v1.5"
|
| 20 |
local_dir = f"./assets/model/{repo_id}"
|
| 21 |
-
|
| 22 |
-
"
|
| 23 |
-
"
|
| 24 |
-
"
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Download each file and save it to the /model/bge directory
|
| 28 |
for file_name in files_to_download:
|
| 29 |
if not os.path.exists(os.path.join(local_dir, file_name)):
|
|
@@ -47,12 +64,15 @@ def check_env():
|
|
| 47 |
"NEO4J_PASSWD",
|
| 48 |
"MODEL_NAME",
|
| 49 |
"MODEL_TYPE",
|
| 50 |
-
"MODEL_API_KEY",
|
| 51 |
"BASE_URL",
|
| 52 |
]
|
| 53 |
for env_name in env_name_list:
|
| 54 |
if env_name not in os.environ or os.environ[env_name] == "":
|
| 55 |
raise ValueError(f"{env_name} is not set...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
ENV_CHECKED = True
|
| 57 |
|
| 58 |
|
|
@@ -61,16 +81,21 @@ class EmbeddingModel:
|
|
| 61 |
|
| 62 |
def __new__(cls, config):
|
| 63 |
if cls._instance is None:
|
|
|
|
| 64 |
cls._instance = super(EmbeddingModel, cls).__new__(cls)
|
| 65 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 66 |
cls._instance.embedding_model = SentenceTransformer(
|
| 67 |
-
model_name_or_path=get_dir(
|
| 68 |
device=device,
|
|
|
|
| 69 |
)
|
| 70 |
print(f"==== using device {device} ====")
|
| 71 |
return cls._instance
|
| 72 |
|
|
|
|
| 73 |
def get_embedding_model(config):
|
|
|
|
|
|
|
| 74 |
return EmbeddingModel(config).embedding_model
|
| 75 |
|
| 76 |
|
|
|
|
| 12 |
EMBEDDING_CHECKED = False
|
| 13 |
|
| 14 |
|
| 15 |
+
def check_embedding(repo_id):
|
| 16 |
+
print("=== check embedding model ===")
|
| 17 |
global EMBEDDING_CHECKED
|
| 18 |
if not EMBEDDING_CHECKED:
|
| 19 |
# Define the repository and files to download
|
|
|
|
| 20 |
local_dir = f"./assets/model/{repo_id}"
|
| 21 |
+
if repo_id in [
|
| 22 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
| 23 |
+
"BAAI/bge-small-en-v1.5",
|
| 24 |
+
"BAAAI/llm_embedder",
|
| 25 |
+
]:
|
| 26 |
+
# repo_id = "sentence-transformers/all-MiniLM-L6-v2"
|
| 27 |
+
# repo_id = "BAAI/bge-small-en-v1.5"
|
| 28 |
+
files_to_download = [
|
| 29 |
+
"config.json",
|
| 30 |
+
"pytorch_model.bin",
|
| 31 |
+
"tokenizer_config.json",
|
| 32 |
+
"vocab.txt",
|
| 33 |
+
]
|
| 34 |
+
elif repo_id in ["Alibaba-NLP/gte-base-en-v1.5"]:
|
| 35 |
+
files_to_download = [
|
| 36 |
+
"config.json",
|
| 37 |
+
"model.safetensors",
|
| 38 |
+
"modules.json",
|
| 39 |
+
"tokenizer.json",
|
| 40 |
+
"sentence_bert_config.json",
|
| 41 |
+
"tokenizer_config.json",
|
| 42 |
+
"vocab.txt",
|
| 43 |
+
]
|
| 44 |
# Download each file and save it to the /model/bge directory
|
| 45 |
for file_name in files_to_download:
|
| 46 |
if not os.path.exists(os.path.join(local_dir, file_name)):
|
|
|
|
| 64 |
"NEO4J_PASSWD",
|
| 65 |
"MODEL_NAME",
|
| 66 |
"MODEL_TYPE",
|
|
|
|
| 67 |
"BASE_URL",
|
| 68 |
]
|
| 69 |
for env_name in env_name_list:
|
| 70 |
if env_name not in os.environ or os.environ[env_name] == "":
|
| 71 |
raise ValueError(f"{env_name} is not set...")
|
| 72 |
+
if os.environ["MODEL_TYPE"] != "Local":
|
| 73 |
+
env_name = "MODEL_API_KEY"
|
| 74 |
+
if env_name not in os.environ or os.environ[env_name] == "":
|
| 75 |
+
raise ValueError(f"{env_name} is not set...")
|
| 76 |
ENV_CHECKED = True
|
| 77 |
|
| 78 |
|
|
|
|
| 81 |
|
| 82 |
def __new__(cls, config):
|
| 83 |
if cls._instance is None:
|
| 84 |
+
local_dir = f"./assets/model/{config.DEFAULT.embedding}"
|
| 85 |
cls._instance = super(EmbeddingModel, cls).__new__(cls)
|
| 86 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 87 |
cls._instance.embedding_model = SentenceTransformer(
|
| 88 |
+
model_name_or_path=get_dir(local_dir),
|
| 89 |
device=device,
|
| 90 |
+
trust_remote_code=True,
|
| 91 |
)
|
| 92 |
print(f"==== using device {device} ====")
|
| 93 |
return cls._instance
|
| 94 |
|
| 95 |
+
|
| 96 |
def get_embedding_model(config):
|
| 97 |
+
print("=== get embedding model ===")
|
| 98 |
+
check_embedding(config.DEFAULT.embedding)
|
| 99 |
return EmbeddingModel(config).embedding_model
|
| 100 |
|
| 101 |
|
src/utils/llms_api.py
CHANGED
|
@@ -49,7 +49,10 @@ class APIHelper(object):
|
|
| 49 |
def get_helper(self):
|
| 50 |
MODEL_TYPE = os.environ["MODEL_TYPE"]
|
| 51 |
MODEL_NAME = os.environ["MODEL_NAME"]
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
BASE_URL = os.environ["BASE_URL"]
|
| 54 |
return HelperCompany.get()[MODEL_TYPE](
|
| 55 |
MODEL_API_KEY, MODEL_NAME, BASE_URL, timeout=None
|
|
@@ -64,6 +67,8 @@ class APIHelper(object):
|
|
| 64 |
"glm4-air",
|
| 65 |
"qwen-max",
|
| 66 |
"qwen-plus",
|
|
|
|
|
|
|
| 67 |
]:
|
| 68 |
raise ValueError(f"Check model name...")
|
| 69 |
|
|
@@ -78,13 +83,13 @@ class APIHelper(object):
|
|
| 78 |
response1 = self.generator.create(
|
| 79 |
messages=message,
|
| 80 |
)
|
| 81 |
-
summary = clean_text(response1
|
| 82 |
message.append({"role": "assistant", "content": summary})
|
| 83 |
message.append(self.prompt.queries[1][0]())
|
| 84 |
response2 = self.generator.create(
|
| 85 |
messages=message,
|
| 86 |
)
|
| 87 |
-
detail = response2
|
| 88 |
motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
|
| 89 |
contribution = clean_text(detail.split(TAG_contr)[1])
|
| 90 |
result = {
|
|
@@ -116,7 +121,7 @@ class APIHelper(object):
|
|
| 116 |
response = self.generator.create(
|
| 117 |
messages=message,
|
| 118 |
)
|
| 119 |
-
entities = response
|
| 120 |
entity_list = entities.strip().split(", ")
|
| 121 |
clean_entity_list = []
|
| 122 |
for entity in entity_list:
|
|
@@ -151,7 +156,7 @@ class APIHelper(object):
|
|
| 151 |
response_brainstorming = self.generator.create(
|
| 152 |
messages=message,
|
| 153 |
)
|
| 154 |
-
brainstorming_ideas = response_brainstorming
|
| 155 |
|
| 156 |
except Exception:
|
| 157 |
traceback.print_exc()
|
|
@@ -178,7 +183,7 @@ class APIHelper(object):
|
|
| 178 |
response = self.generator.create(
|
| 179 |
messages=message,
|
| 180 |
)
|
| 181 |
-
problem = response
|
| 182 |
except Exception:
|
| 183 |
traceback.print_exc()
|
| 184 |
return None
|
|
@@ -207,7 +212,7 @@ class APIHelper(object):
|
|
| 207 |
response = self.generator.create(
|
| 208 |
messages=message,
|
| 209 |
)
|
| 210 |
-
problem = response
|
| 211 |
except Exception:
|
| 212 |
traceback.print_exc()
|
| 213 |
return None
|
|
@@ -228,7 +233,7 @@ class APIHelper(object):
|
|
| 228 |
response = self.generator.create(
|
| 229 |
messages=message,
|
| 230 |
)
|
| 231 |
-
inspiration = response
|
| 232 |
except Exception:
|
| 233 |
traceback.print_exc()
|
| 234 |
return None
|
|
@@ -254,7 +259,7 @@ class APIHelper(object):
|
|
| 254 |
response = self.generator.create(
|
| 255 |
messages=message,
|
| 256 |
)
|
| 257 |
-
inspiration = response
|
| 258 |
except Exception:
|
| 259 |
traceback.print_exc()
|
| 260 |
return None
|
|
@@ -282,7 +287,7 @@ class APIHelper(object):
|
|
| 282 |
response = self.generator.create(
|
| 283 |
messages=message,
|
| 284 |
)
|
| 285 |
-
idea = response
|
| 286 |
except Exception:
|
| 287 |
traceback.print_exc()
|
| 288 |
return None
|
|
@@ -314,7 +319,7 @@ class APIHelper(object):
|
|
| 314 |
response = self.generator.create(
|
| 315 |
messages=message,
|
| 316 |
)
|
| 317 |
-
idea = response
|
| 318 |
except Exception:
|
| 319 |
traceback.print_exc()
|
| 320 |
return None
|
|
@@ -340,7 +345,7 @@ class APIHelper(object):
|
|
| 340 |
response = self.generator.create(
|
| 341 |
messages=message,
|
| 342 |
)
|
| 343 |
-
idea = response
|
| 344 |
except Exception:
|
| 345 |
traceback.print_exc()
|
| 346 |
return None
|
|
@@ -372,7 +377,7 @@ class APIHelper(object):
|
|
| 372 |
response = self.generator.create(
|
| 373 |
messages=message,
|
| 374 |
)
|
| 375 |
-
idea = response
|
| 376 |
except Exception:
|
| 377 |
traceback.print_exc()
|
| 378 |
return None
|
|
@@ -391,7 +396,7 @@ class APIHelper(object):
|
|
| 391 |
response = self.generator.create(
|
| 392 |
messages=message,
|
| 393 |
)
|
| 394 |
-
idea = response
|
| 395 |
except Exception:
|
| 396 |
traceback.print_exc()
|
| 397 |
return None
|
|
@@ -413,7 +418,7 @@ class APIHelper(object):
|
|
| 413 |
response = self.generator.create(
|
| 414 |
messages=message,
|
| 415 |
)
|
| 416 |
-
idea_filtered = response
|
| 417 |
except Exception:
|
| 418 |
traceback.print_exc()
|
| 419 |
return None
|
|
@@ -435,7 +440,7 @@ class APIHelper(object):
|
|
| 435 |
response = self.generator.create(
|
| 436 |
messages=message,
|
| 437 |
)
|
| 438 |
-
idea_modified = response
|
| 439 |
except Exception:
|
| 440 |
traceback.print_exc()
|
| 441 |
return None
|
|
@@ -454,7 +459,7 @@ class APIHelper(object):
|
|
| 454 |
response = self.generator.create(
|
| 455 |
messages=message,
|
| 456 |
)
|
| 457 |
-
ground_truth = response
|
| 458 |
except Exception:
|
| 459 |
traceback.print_exc()
|
| 460 |
return ground_truth
|
|
@@ -469,7 +474,7 @@ class APIHelper(object):
|
|
| 469 |
response = self.generator.create(
|
| 470 |
messages=message,
|
| 471 |
)
|
| 472 |
-
idea_norm = response
|
| 473 |
except Exception:
|
| 474 |
traceback.print_exc()
|
| 475 |
return None
|
|
@@ -492,7 +497,7 @@ class APIHelper(object):
|
|
| 492 |
messages=message,
|
| 493 |
max_tokens=10,
|
| 494 |
)
|
| 495 |
-
index = response
|
| 496 |
except Exception:
|
| 497 |
traceback.print_exc()
|
| 498 |
return None
|
|
@@ -509,7 +514,7 @@ class APIHelper(object):
|
|
| 509 |
messages=message,
|
| 510 |
max_tokens=10,
|
| 511 |
)
|
| 512 |
-
score = response
|
| 513 |
except Exception:
|
| 514 |
traceback.print_exc()
|
| 515 |
return None
|
|
@@ -548,7 +553,7 @@ class APIHelper(object):
|
|
| 548 |
stop=None,
|
| 549 |
seed=0,
|
| 550 |
)
|
| 551 |
-
content = response
|
| 552 |
new_msg_history = new_msg_history + [
|
| 553 |
{"role": "assistant", "content": content}
|
| 554 |
]
|
|
@@ -577,7 +582,7 @@ class APIHelper(object):
|
|
| 577 |
response = self.generator.create(
|
| 578 |
messages=message,
|
| 579 |
)
|
| 580 |
-
result = response
|
| 581 |
except Exception:
|
| 582 |
traceback.print_exc()
|
| 583 |
return None
|
|
@@ -601,7 +606,7 @@ class APIHelper(object):
|
|
| 601 |
response = self.generator.create(
|
| 602 |
messages=message,
|
| 603 |
)
|
| 604 |
-
result = response
|
| 605 |
except Exception:
|
| 606 |
traceback.print_exc()
|
| 607 |
return None
|
|
@@ -625,7 +630,7 @@ class APIHelper(object):
|
|
| 625 |
response = self.generator.create(
|
| 626 |
messages=message,
|
| 627 |
)
|
| 628 |
-
result = response
|
| 629 |
except Exception:
|
| 630 |
traceback.print_exc()
|
| 631 |
return None
|
|
@@ -649,7 +654,7 @@ class APIHelper(object):
|
|
| 649 |
response = self.generator.create(
|
| 650 |
messages=message,
|
| 651 |
)
|
| 652 |
-
result = response
|
| 653 |
except Exception:
|
| 654 |
traceback.print_exc()
|
| 655 |
return None
|
|
@@ -673,7 +678,7 @@ class APIHelper(object):
|
|
| 673 |
response = self.generator.create(
|
| 674 |
messages=message,
|
| 675 |
)
|
| 676 |
-
result = response
|
| 677 |
except Exception:
|
| 678 |
traceback.print_exc()
|
| 679 |
return None
|
|
|
|
| 49 |
def get_helper(self):
|
| 50 |
MODEL_TYPE = os.environ["MODEL_TYPE"]
|
| 51 |
MODEL_NAME = os.environ["MODEL_NAME"]
|
| 52 |
+
if MODEL_NAME != "local":
|
| 53 |
+
MODEL_API_KEY = os.environ["MODEL_API_KEY"]
|
| 54 |
+
else:
|
| 55 |
+
MODEL_API_KEY = ""
|
| 56 |
BASE_URL = os.environ["BASE_URL"]
|
| 57 |
return HelperCompany.get()[MODEL_TYPE](
|
| 58 |
MODEL_API_KEY, MODEL_NAME, BASE_URL, timeout=None
|
|
|
|
| 67 |
"glm4-air",
|
| 68 |
"qwen-max",
|
| 69 |
"qwen-plus",
|
| 70 |
+
"gpt-4o-mini",
|
| 71 |
+
"local",
|
| 72 |
]:
|
| 73 |
raise ValueError(f"Check model name...")
|
| 74 |
|
|
|
|
| 83 |
response1 = self.generator.create(
|
| 84 |
messages=message,
|
| 85 |
)
|
| 86 |
+
summary = clean_text(response1)
|
| 87 |
message.append({"role": "assistant", "content": summary})
|
| 88 |
message.append(self.prompt.queries[1][0]())
|
| 89 |
response2 = self.generator.create(
|
| 90 |
messages=message,
|
| 91 |
)
|
| 92 |
+
detail = response2
|
| 93 |
motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
|
| 94 |
contribution = clean_text(detail.split(TAG_contr)[1])
|
| 95 |
result = {
|
|
|
|
| 121 |
response = self.generator.create(
|
| 122 |
messages=message,
|
| 123 |
)
|
| 124 |
+
entities = response
|
| 125 |
entity_list = entities.strip().split(", ")
|
| 126 |
clean_entity_list = []
|
| 127 |
for entity in entity_list:
|
|
|
|
| 156 |
response_brainstorming = self.generator.create(
|
| 157 |
messages=message,
|
| 158 |
)
|
| 159 |
+
brainstorming_ideas = response_brainstorming
|
| 160 |
|
| 161 |
except Exception:
|
| 162 |
traceback.print_exc()
|
|
|
|
| 183 |
response = self.generator.create(
|
| 184 |
messages=message,
|
| 185 |
)
|
| 186 |
+
problem = response
|
| 187 |
except Exception:
|
| 188 |
traceback.print_exc()
|
| 189 |
return None
|
|
|
|
| 212 |
response = self.generator.create(
|
| 213 |
messages=message,
|
| 214 |
)
|
| 215 |
+
problem = response
|
| 216 |
except Exception:
|
| 217 |
traceback.print_exc()
|
| 218 |
return None
|
|
|
|
| 233 |
response = self.generator.create(
|
| 234 |
messages=message,
|
| 235 |
)
|
| 236 |
+
inspiration = response
|
| 237 |
except Exception:
|
| 238 |
traceback.print_exc()
|
| 239 |
return None
|
|
|
|
| 259 |
response = self.generator.create(
|
| 260 |
messages=message,
|
| 261 |
)
|
| 262 |
+
inspiration = response
|
| 263 |
except Exception:
|
| 264 |
traceback.print_exc()
|
| 265 |
return None
|
|
|
|
| 287 |
response = self.generator.create(
|
| 288 |
messages=message,
|
| 289 |
)
|
| 290 |
+
idea = response
|
| 291 |
except Exception:
|
| 292 |
traceback.print_exc()
|
| 293 |
return None
|
|
|
|
| 319 |
response = self.generator.create(
|
| 320 |
messages=message,
|
| 321 |
)
|
| 322 |
+
idea = response
|
| 323 |
except Exception:
|
| 324 |
traceback.print_exc()
|
| 325 |
return None
|
|
|
|
| 345 |
response = self.generator.create(
|
| 346 |
messages=message,
|
| 347 |
)
|
| 348 |
+
idea = response
|
| 349 |
except Exception:
|
| 350 |
traceback.print_exc()
|
| 351 |
return None
|
|
|
|
| 377 |
response = self.generator.create(
|
| 378 |
messages=message,
|
| 379 |
)
|
| 380 |
+
idea = response
|
| 381 |
except Exception:
|
| 382 |
traceback.print_exc()
|
| 383 |
return None
|
|
|
|
| 396 |
response = self.generator.create(
|
| 397 |
messages=message,
|
| 398 |
)
|
| 399 |
+
idea = response
|
| 400 |
except Exception:
|
| 401 |
traceback.print_exc()
|
| 402 |
return None
|
|
|
|
| 418 |
response = self.generator.create(
|
| 419 |
messages=message,
|
| 420 |
)
|
| 421 |
+
idea_filtered = response
|
| 422 |
except Exception:
|
| 423 |
traceback.print_exc()
|
| 424 |
return None
|
|
|
|
| 440 |
response = self.generator.create(
|
| 441 |
messages=message,
|
| 442 |
)
|
| 443 |
+
idea_modified = response
|
| 444 |
except Exception:
|
| 445 |
traceback.print_exc()
|
| 446 |
return None
|
|
|
|
| 459 |
response = self.generator.create(
|
| 460 |
messages=message,
|
| 461 |
)
|
| 462 |
+
ground_truth = response
|
| 463 |
except Exception:
|
| 464 |
traceback.print_exc()
|
| 465 |
return ground_truth
|
|
|
|
| 474 |
response = self.generator.create(
|
| 475 |
messages=message,
|
| 476 |
)
|
| 477 |
+
idea_norm = response
|
| 478 |
except Exception:
|
| 479 |
traceback.print_exc()
|
| 480 |
return None
|
|
|
|
| 497 |
messages=message,
|
| 498 |
max_tokens=10,
|
| 499 |
)
|
| 500 |
+
index = response
|
| 501 |
except Exception:
|
| 502 |
traceback.print_exc()
|
| 503 |
return None
|
|
|
|
| 514 |
messages=message,
|
| 515 |
max_tokens=10,
|
| 516 |
)
|
| 517 |
+
score = response
|
| 518 |
except Exception:
|
| 519 |
traceback.print_exc()
|
| 520 |
return None
|
|
|
|
| 553 |
stop=None,
|
| 554 |
seed=0,
|
| 555 |
)
|
| 556 |
+
content = response
|
| 557 |
new_msg_history = new_msg_history + [
|
| 558 |
{"role": "assistant", "content": content}
|
| 559 |
]
|
|
|
|
| 582 |
response = self.generator.create(
|
| 583 |
messages=message,
|
| 584 |
)
|
| 585 |
+
result = response
|
| 586 |
except Exception:
|
| 587 |
traceback.print_exc()
|
| 588 |
return None
|
|
|
|
| 606 |
response = self.generator.create(
|
| 607 |
messages=message,
|
| 608 |
)
|
| 609 |
+
result = response
|
| 610 |
except Exception:
|
| 611 |
traceback.print_exc()
|
| 612 |
return None
|
|
|
|
| 630 |
response = self.generator.create(
|
| 631 |
messages=message,
|
| 632 |
)
|
| 633 |
+
result = response
|
| 634 |
except Exception:
|
| 635 |
traceback.print_exc()
|
| 636 |
return None
|
|
|
|
| 654 |
response = self.generator.create(
|
| 655 |
messages=message,
|
| 656 |
)
|
| 657 |
+
result = response
|
| 658 |
except Exception:
|
| 659 |
traceback.print_exc()
|
| 660 |
return None
|
|
|
|
| 678 |
response = self.generator.create(
|
| 679 |
messages=message,
|
| 680 |
)
|
| 681 |
+
result = response
|
| 682 |
except Exception:
|
| 683 |
traceback.print_exc()
|
| 684 |
return None
|
src/utils/paper_client.py
CHANGED
|
@@ -8,6 +8,7 @@ from collections import defaultdict, deque
|
|
| 8 |
from py2neo import Graph, Node, Relationship
|
| 9 |
from loguru import logger
|
| 10 |
|
|
|
|
| 11 |
class PaperClient:
|
| 12 |
_instance = None
|
| 13 |
_initialized = False
|
|
@@ -43,10 +44,28 @@ class PaperClient:
|
|
| 43 |
with self.driver.session() as session:
|
| 44 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 45 |
if result:
|
| 46 |
-
paper_from_client = result[0][
|
| 47 |
if paper_from_client is not None:
|
| 48 |
paper.update(paper_from_client)
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
def get_paper_attribute(self, paper_id, attribute_name):
|
| 51 |
query = f"""
|
| 52 |
MATCH (p:Paper {{hash_id: {paper_id}}})
|
|
@@ -55,11 +74,11 @@ class PaperClient:
|
|
| 55 |
with self.driver.session() as session:
|
| 56 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 57 |
if result:
|
| 58 |
-
return result[0][
|
| 59 |
else:
|
| 60 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
| 61 |
return None
|
| 62 |
-
|
| 63 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
| 64 |
query = f"""
|
| 65 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
|
@@ -68,7 +87,7 @@ class PaperClient:
|
|
| 68 |
with self.driver.session() as session:
|
| 69 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 70 |
if result:
|
| 71 |
-
return result[0][
|
| 72 |
else:
|
| 73 |
return None
|
| 74 |
|
|
@@ -81,71 +100,50 @@ class PaperClient:
|
|
| 81 |
RETURN p.hash_id as hash_id
|
| 82 |
"""
|
| 83 |
with self.driver.session() as session:
|
| 84 |
-
result = session.execute_read(
|
|
|
|
|
|
|
| 85 |
if result:
|
| 86 |
-
return [record[
|
| 87 |
else:
|
| 88 |
return []
|
| 89 |
-
|
| 90 |
-
def
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
visited = set([entity_name])
|
| 95 |
-
related_entities = set()
|
| 96 |
-
|
| 97 |
-
while queue:
|
| 98 |
-
batch_queue = [queue.popleft() for _ in range(len(queue))]
|
| 99 |
-
batch_entities = [item[0] for item in batch_queue]
|
| 100 |
-
batch_depths = [item[1] for item in batch_queue]
|
| 101 |
-
|
| 102 |
-
if all(depth >= n for depth in batch_depths):
|
| 103 |
-
continue
|
| 104 |
-
if relation_name == "related":
|
| 105 |
-
query = """
|
| 106 |
-
UNWIND $batch_entities AS entity_name
|
| 107 |
-
MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
|
| 108 |
-
WHERE e1 <> e2
|
| 109 |
-
WITH e1, e2, COUNT(p) AS common_papers, entity_name
|
| 110 |
-
WHERE common_papers > $k
|
| 111 |
-
RETURN e2.name AS entities, entity_name AS source_entity, common_papers
|
| 112 |
-
"""
|
| 113 |
-
elif relation_name == "connect":
|
| 114 |
-
query = """
|
| 115 |
-
UNWIND $batch_entities AS entity_name
|
| 116 |
-
MATCH (e1:Entity {name: entity_name})-[r:CONNECT]-(e2:Entity)
|
| 117 |
-
WHERE e1 <> e2 and r.strength >= $k
|
| 118 |
-
WITH e1, e2, entity_name
|
| 119 |
-
RETURN e2.name AS entities, entity_name AS source_entity
|
| 120 |
-
"""
|
| 121 |
-
with self.driver.session() as session:
|
| 122 |
-
result = session.execute_read(lambda tx: tx.run(query, batch_entities=batch_entities, k=k).data())
|
| 123 |
-
|
| 124 |
-
for record in result:
|
| 125 |
-
entity = record['entities']
|
| 126 |
-
source_entity = record['source_entity']
|
| 127 |
-
if entity not in visited:
|
| 128 |
-
visited.add(entity)
|
| 129 |
-
queue.append((entity, batch_depths[batch_entities.index(source_entity)] + 1))
|
| 130 |
-
related_entities.add(entity)
|
| 131 |
-
|
| 132 |
-
return list(related_entities)
|
| 133 |
-
related_entities = bfs_query(entity_name, n, k)
|
| 134 |
-
if entity_name in related_entities:
|
| 135 |
-
related_entities.remove(entity_name)
|
| 136 |
-
return related_entities
|
| 137 |
-
|
| 138 |
-
def find_entities_by_paper(self, hash_id: int):
|
| 139 |
query = """
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
"""
|
| 143 |
with self.driver.session() as session:
|
| 144 |
-
result = session.execute_read(
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
def find_paper_by_entity(self, entity_name):
|
| 151 |
query = """
|
|
@@ -153,18 +151,19 @@ class PaperClient:
|
|
| 153 |
RETURN p.hash_id AS hash_id
|
| 154 |
"""
|
| 155 |
with self.driver.session() as session:
|
| 156 |
-
result = session.execute_read(
|
|
|
|
|
|
|
| 157 |
if result:
|
| 158 |
-
return [record[
|
| 159 |
else:
|
| 160 |
return []
|
| 161 |
-
|
| 162 |
# TODO: @云翔
|
| 163 |
# 增加通过entity返回包含entity语句的功能
|
| 164 |
def find_sentence_by_entity(self, entity_name):
|
| 165 |
# Return: list(str)
|
| 166 |
return []
|
| 167 |
-
|
| 168 |
|
| 169 |
def find_sentences_by_entity(self, entity_name):
|
| 170 |
query = """
|
|
@@ -178,14 +177,25 @@ class PaperClient:
|
|
| 178 |
p.hash_id AS hash_id
|
| 179 |
"""
|
| 180 |
sentences = []
|
| 181 |
-
|
| 182 |
with self.driver.session() as session:
|
| 183 |
-
result = session.execute_read(
|
|
|
|
|
|
|
| 184 |
for record in result:
|
| 185 |
-
for key in [
|
| 186 |
if record[key]:
|
| 187 |
-
filtered_sentences = [
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
return sentences
|
| 191 |
|
|
@@ -194,9 +204,11 @@ class PaperClient:
|
|
| 194 |
MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
|
| 195 |
"""
|
| 196 |
with self.driver.session() as session:
|
| 197 |
-
result = session.execute_read(
|
|
|
|
|
|
|
| 198 |
if result:
|
| 199 |
-
return [record[
|
| 200 |
else:
|
| 201 |
return []
|
| 202 |
|
|
@@ -230,7 +242,26 @@ class PaperClient:
|
|
| 230 |
RETURN p
|
| 231 |
"""
|
| 232 |
with self.driver.session() as session:
|
| 233 |
-
result = session.execute_write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
def check_entity_node_count(self, hash_id: int):
|
| 236 |
query_check_count = """
|
|
@@ -239,7 +270,9 @@ class PaperClient:
|
|
| 239 |
"""
|
| 240 |
with self.driver.session() as session:
|
| 241 |
# Check the number of related entities
|
| 242 |
-
result = session.execute_read(
|
|
|
|
|
|
|
| 243 |
if result[0]["entity_count"] > 3:
|
| 244 |
return False
|
| 245 |
return True
|
|
@@ -254,16 +287,30 @@ class PaperClient:
|
|
| 254 |
"""
|
| 255 |
with self.driver.session() as session:
|
| 256 |
for entity_name in entities:
|
| 257 |
-
result = session.execute_write(
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
def add_paper_citation(self, paper: dict):
|
| 260 |
query = """
|
| 261 |
MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
|
| 262 |
"""
|
| 263 |
with self.driver.session() as session:
|
| 264 |
-
result = session.execute_write(
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
if hash_id is not None:
|
| 268 |
query = """
|
| 269 |
MATCH (p:Paper {hash_id: $hash_id})
|
|
@@ -271,119 +318,302 @@ class PaperClient:
|
|
| 271 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
| 272 |
"""
|
| 273 |
with self.driver.session() as session:
|
| 274 |
-
results = session.execute_write(
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
MATCH (p:Paper)
|
| 278 |
WHERE p.abstract IS NOT NULL
|
| 279 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
|
|
|
| 280 |
"""
|
| 281 |
with self.driver.session() as session:
|
| 282 |
-
results = session.execute_write(
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
with self.driver.session() as session:
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
-
def add_paper_bg_embedding(self, embedding_model, hash_id=None):
|
| 297 |
if hash_id is not None:
|
| 298 |
query = """
|
| 299 |
MATCH (p:Paper {hash_id: $hash_id})
|
| 300 |
WHERE p.motivation IS NOT NULL
|
| 301 |
-
RETURN p.motivation AS context, p.hash_id AS hash_id
|
| 302 |
"""
|
| 303 |
with self.driver.session() as session:
|
| 304 |
-
results = session.execute_write(
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
MATCH (p:Paper)
|
| 308 |
WHERE p.motivation IS NOT NULL
|
| 309 |
-
RETURN p.motivation AS context, p.hash_id AS hash_id
|
|
|
|
| 310 |
"""
|
| 311 |
with self.driver.session() as session:
|
| 312 |
-
results = session.execute_write(
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
with self.driver.session() as session:
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
if hash_id is not None:
|
| 328 |
query = """
|
| 329 |
MATCH (p:Paper {hash_id: $hash_id})
|
| 330 |
WHERE p.contribution IS NOT NULL
|
| 331 |
-
RETURN p.contribution AS context, p.hash_id AS hash_id
|
| 332 |
"""
|
| 333 |
with self.driver.session() as session:
|
| 334 |
-
results = session.execute_write(
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
MATCH (p:Paper)
|
| 338 |
WHERE p.contribution IS NOT NULL
|
| 339 |
-
RETURN p.contribution AS context, p.hash_id AS hash_id
|
|
|
|
| 340 |
"""
|
| 341 |
with self.driver.session() as session:
|
| 342 |
-
results = session.execute_write(
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
with self.driver.session() as session:
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
if hash_id is not None:
|
| 359 |
query = """
|
| 360 |
MATCH (p:Paper {hash_id: $hash_id})
|
| 361 |
WHERE p.summary IS NOT NULL
|
| 362 |
-
RETURN p.summary AS context, p.hash_id AS hash_id
|
| 363 |
"""
|
| 364 |
with self.driver.session() as session:
|
| 365 |
-
results = session.execute_write(
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
MATCH (p:Paper)
|
| 369 |
WHERE p.summary IS NOT NULL
|
| 370 |
-
RETURN p.summary AS context, p.hash_id AS hash_id
|
|
|
|
| 371 |
"""
|
| 372 |
with self.driver.session() as session:
|
| 373 |
-
results = session.execute_write(
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
with self.driver.session() as session:
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
|
| 388 |
query = f"""
|
| 389 |
MATCH (paper:Paper)
|
|
@@ -394,8 +624,10 @@ class PaperClient:
|
|
| 394 |
ORDER BY score DESC LIMIT {k}
|
| 395 |
"""
|
| 396 |
with self.driver.session() as session:
|
| 397 |
-
results = session.execute_read(
|
| 398 |
-
|
|
|
|
|
|
|
| 399 |
for result in results:
|
| 400 |
related_paper.append(result["paper"]["hash_id"])
|
| 401 |
return related_paper
|
|
@@ -417,7 +649,7 @@ class PaperClient:
|
|
| 417 |
"""
|
| 418 |
with self.driver.session() as session:
|
| 419 |
session.execute_write(lambda tx: tx.run(query).data())
|
| 420 |
-
|
| 421 |
def filter_paper_id_list(self, paper_id_list, year="2024"):
|
| 422 |
if not paper_id_list:
|
| 423 |
return []
|
|
@@ -429,12 +661,14 @@ class PaperClient:
|
|
| 429 |
RETURN p.hash_id AS hash_id
|
| 430 |
"""
|
| 431 |
with self.driver.session() as session:
|
| 432 |
-
result = session.execute_read(
|
| 433 |
-
|
| 434 |
-
|
|
|
|
|
|
|
| 435 |
existing_paper_ids = list(set(existing_paper_ids))
|
| 436 |
return existing_paper_ids
|
| 437 |
-
|
| 438 |
def check_index_exists(self):
|
| 439 |
query = "SHOW INDEXES"
|
| 440 |
with self.driver.session() as session:
|
|
@@ -451,7 +685,7 @@ class PaperClient:
|
|
| 451 |
"""
|
| 452 |
with self.driver.session() as session:
|
| 453 |
session.execute_write(lambda tx: tx.run(query).data())
|
| 454 |
-
|
| 455 |
def get_entity_related_paper_num(self, entity_name):
|
| 456 |
query = """
|
| 457 |
MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
|
|
@@ -459,10 +693,30 @@ class PaperClient:
|
|
| 459 |
RETURN PaperCount
|
| 460 |
"""
|
| 461 |
with self.driver.session() as session:
|
| 462 |
-
result = session.execute_read(
|
| 463 |
-
|
|
|
|
|
|
|
| 464 |
return paper_num
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
def get_entity_text(self):
|
| 467 |
query = """
|
| 468 |
MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
|
|
@@ -472,11 +726,13 @@ class PaperClient:
|
|
| 472 |
"""
|
| 473 |
with self.driver.session() as session:
|
| 474 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 475 |
-
text_list = [record[
|
| 476 |
return text_list
|
| 477 |
-
|
| 478 |
def get_entity_combinations(self, venue_name, year):
|
| 479 |
-
def process_paper_relationships(
|
|
|
|
|
|
|
| 480 |
if entity_name_2 < entity_name_1:
|
| 481 |
entity_name_1, entity_name_2 = entity_name_2, entity_name_1
|
| 482 |
query = """
|
|
@@ -486,13 +742,17 @@ class PaperClient:
|
|
| 486 |
ON CREATE SET r.strength = 1
|
| 487 |
ON MATCH SET r.strength = r.strength + 1
|
| 488 |
"""
|
| 489 |
-
sentences = re.split(r
|
| 490 |
for sentence in sentences:
|
| 491 |
sentence = sentence.lower()
|
| 492 |
if entity_name_1 in sentence and entity_name_2 in sentence:
|
| 493 |
# 如果两个实体在同一句话中出现过,则创建或更新 CONNECT 关系
|
| 494 |
session.execute_write(
|
| 495 |
-
lambda tx: tx.run(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
)
|
| 497 |
# logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
|
| 498 |
break # 如果找到一次出现就可以退出循环
|
|
@@ -506,13 +766,17 @@ class PaperClient:
|
|
| 506 |
RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
|
| 507 |
"""
|
| 508 |
with self.driver.session() as session:
|
| 509 |
-
result = session.execute_read(
|
|
|
|
|
|
|
| 510 |
for record in tqdm(result):
|
| 511 |
paper_id = record["hash_id"]
|
| 512 |
-
entity_name_1 = record[
|
| 513 |
-
entity_name_2 = record[
|
| 514 |
abstract = self.get_paper_attribute(paper_id, "abstract")
|
| 515 |
-
process_paper_relationships(
|
|
|
|
|
|
|
| 516 |
|
| 517 |
def build_citemap(self):
|
| 518 |
citemap = defaultdict(set)
|
|
@@ -523,8 +787,8 @@ class PaperClient:
|
|
| 523 |
with self.driver.session() as session:
|
| 524 |
results = session.execute_read(lambda tx: tx.run(query).data())
|
| 525 |
for result in results:
|
| 526 |
-
hash_id = result[
|
| 527 |
-
cite_id_list = result[
|
| 528 |
if cite_id_list:
|
| 529 |
for cited_id in cite_id_list:
|
| 530 |
citemap[hash_id].add(cited_id)
|
|
@@ -537,12 +801,17 @@ class PaperClient:
|
|
| 537 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
| 538 |
graph = Graph(URI, auth=AUTH)
|
| 539 |
# 创建一个字典来保存数据
|
|
|
|
| 540 |
data = {"nodes": [], "relationships": []}
|
| 541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
|
| 543 |
-
WHERE p.venue_name='iclr' and p.year='2024'
|
| 544 |
RETURN p, e, r
|
| 545 |
"""
|
|
|
|
| 546 |
results = graph.run(query)
|
| 547 |
# 处理查询结果
|
| 548 |
for record in tqdm(results):
|
|
@@ -550,39 +819,46 @@ class PaperClient:
|
|
| 550 |
entity_node = record["e"]
|
| 551 |
relationship = record["r"]
|
| 552 |
# 将节点数据加入字典
|
| 553 |
-
data["nodes"].append(
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
# 将关系数据加入字典
|
| 564 |
-
data["relationships"].append(
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
|
|
|
|
|
|
|
|
|
| 570 |
query = """
|
| 571 |
MATCH (p:Paper)
|
| 572 |
WHERE p.venue_name='acl' and p.year='2024'
|
| 573 |
RETURN p
|
| 574 |
"""
|
| 575 |
-
"""
|
| 576 |
results = graph.run(query)
|
| 577 |
for record in tqdm(results):
|
| 578 |
paper_node = record["p"]
|
| 579 |
# 将节点数据加入字典
|
| 580 |
-
data["nodes"].append(
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
|
|
|
| 586 |
# 去除重复节点
|
| 587 |
# data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
|
| 588 |
unique_nodes = []
|
|
@@ -595,9 +871,11 @@ class PaperClient:
|
|
| 595 |
unique_nodes.append(node)
|
| 596 |
data["nodes"] = unique_nodes
|
| 597 |
# 将数据保存为 JSON 文件
|
| 598 |
-
with open(
|
|
|
|
|
|
|
| 599 |
json.dump(data, f, ensure_ascii=False, indent=4)
|
| 600 |
-
|
| 601 |
def neo4j_import_data(self):
|
| 602 |
# clear_database() # 清空数据库,谨慎执行
|
| 603 |
URI = os.environ["NEO4J_URL"]
|
|
@@ -606,7 +884,9 @@ class PaperClient:
|
|
| 606 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
| 607 |
graph = Graph(URI, auth=AUTH)
|
| 608 |
# 从 JSON 文件中读取数据
|
| 609 |
-
with open(
|
|
|
|
|
|
|
| 610 |
data = json.load(f)
|
| 611 |
# 创建节点
|
| 612 |
nodes = {}
|
|
|
|
| 8 |
from py2neo import Graph, Node, Relationship
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
+
|
| 12 |
class PaperClient:
|
| 13 |
_instance = None
|
| 14 |
_initialized = False
|
|
|
|
| 44 |
with self.driver.session() as session:
|
| 45 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 46 |
if result:
|
| 47 |
+
paper_from_client = result[0]["p"]
|
| 48 |
if paper_from_client is not None:
|
| 49 |
paper.update(paper_from_client)
|
| 50 |
+
|
| 51 |
+
def update_papers_from_client(self, paper_id_list):
|
| 52 |
+
query = """
|
| 53 |
+
UNWIND $papers AS paper
|
| 54 |
+
MATCH (p:Paper {hash_id: paper.hash_id})
|
| 55 |
+
RETURN p as result
|
| 56 |
+
"""
|
| 57 |
+
paper_data = [
|
| 58 |
+
{
|
| 59 |
+
"hash_id": hash_id,
|
| 60 |
+
}
|
| 61 |
+
for hash_id in paper_id_list
|
| 62 |
+
]
|
| 63 |
+
with self.driver.session() as session:
|
| 64 |
+
result = session.execute_read(
|
| 65 |
+
lambda tx: tx.run(query, papers=paper_data).data()
|
| 66 |
+
)
|
| 67 |
+
return [r["result"] for r in result]
|
| 68 |
+
|
| 69 |
def get_paper_attribute(self, paper_id, attribute_name):
|
| 70 |
query = f"""
|
| 71 |
MATCH (p:Paper {{hash_id: {paper_id}}})
|
|
|
|
| 74 |
with self.driver.session() as session:
|
| 75 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 76 |
if result:
|
| 77 |
+
return result[0]["attributeValue"]
|
| 78 |
else:
|
| 79 |
logger.error(f"paper id {paper_id} get {attribute_name} failed.")
|
| 80 |
return None
|
| 81 |
+
|
| 82 |
def get_paper_by_attribute(self, attribute_name, anttribute_value):
|
| 83 |
query = f"""
|
| 84 |
MATCH (p:Paper {{{attribute_name}: '{anttribute_value}'}})
|
|
|
|
| 87 |
with self.driver.session() as session:
|
| 88 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 89 |
if result:
|
| 90 |
+
return result[0]["p"]
|
| 91 |
else:
|
| 92 |
return None
|
| 93 |
|
|
|
|
| 100 |
RETURN p.hash_id as hash_id
|
| 101 |
"""
|
| 102 |
with self.driver.session() as session:
|
| 103 |
+
result = session.execute_read(
|
| 104 |
+
lambda tx: tx.run(query, entity=entity).data()
|
| 105 |
+
)
|
| 106 |
if result:
|
| 107 |
+
return [record["hash_id"] for record in result]
|
| 108 |
else:
|
| 109 |
return []
|
| 110 |
+
|
| 111 |
+
def find_related_entities_by_entity_list(
|
| 112 |
+
self, entity_names, n=1, k=3, relation_name="related"
|
| 113 |
+
):
|
| 114 |
+
related_entities = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
query = """
|
| 116 |
+
UNWIND $batch_entities AS entity_name
|
| 117 |
+
MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
|
| 118 |
+
WHERE e1 <> e2
|
| 119 |
+
WITH e1, e2, COUNT(p) AS common_papers, entity_name
|
| 120 |
+
WHERE common_papers > $k
|
| 121 |
+
RETURN e2.name AS entities, entity_name AS source_entity, common_papers
|
| 122 |
"""
|
| 123 |
with self.driver.session() as session:
|
| 124 |
+
result = session.execute_read(
|
| 125 |
+
lambda tx: tx.run(query, batch_entities=entity_names, k=k).data()
|
| 126 |
+
)
|
| 127 |
+
for record in result:
|
| 128 |
+
entity = record["entities"]
|
| 129 |
+
related_entities.add(entity)
|
| 130 |
+
return list(related_entities)
|
| 131 |
+
|
| 132 |
+
def find_entities_by_paper_list(self, hash_ids: list):
|
| 133 |
+
query = """
|
| 134 |
+
UNWIND $hash_ids AS hash_id
|
| 135 |
+
MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: hash_id})
|
| 136 |
+
RETURN hash_id, e.name AS entity_name
|
| 137 |
+
"""
|
| 138 |
+
with self.driver.session() as session:
|
| 139 |
+
result = session.execute_read(
|
| 140 |
+
lambda tx: tx.run(query, hash_ids=hash_ids).data()
|
| 141 |
+
)
|
| 142 |
+
# 按照每个 hash_id 分组实体
|
| 143 |
+
entity_list = []
|
| 144 |
+
for record in result:
|
| 145 |
+
entity_list.append(record["entity_name"])
|
| 146 |
+
return entity_list
|
| 147 |
|
| 148 |
def find_paper_by_entity(self, entity_name):
|
| 149 |
query = """
|
|
|
|
| 151 |
RETURN p.hash_id AS hash_id
|
| 152 |
"""
|
| 153 |
with self.driver.session() as session:
|
| 154 |
+
result = session.execute_read(
|
| 155 |
+
lambda tx: tx.run(query, entity_name=entity_name).data()
|
| 156 |
+
)
|
| 157 |
if result:
|
| 158 |
+
return [record["hash_id"] for record in result]
|
| 159 |
else:
|
| 160 |
return []
|
| 161 |
+
|
| 162 |
# TODO: @云翔
|
| 163 |
# 增加通过entity返回包含entity语句的功能
|
| 164 |
def find_sentence_by_entity(self, entity_name):
|
| 165 |
# Return: list(str)
|
| 166 |
return []
|
|
|
|
| 167 |
|
| 168 |
def find_sentences_by_entity(self, entity_name):
|
| 169 |
query = """
|
|
|
|
| 177 |
p.hash_id AS hash_id
|
| 178 |
"""
|
| 179 |
sentences = []
|
| 180 |
+
|
| 181 |
with self.driver.session() as session:
|
| 182 |
+
result = session.execute_read(
|
| 183 |
+
lambda tx: tx.run(query, entity_name=entity_name).data()
|
| 184 |
+
)
|
| 185 |
for record in result:
|
| 186 |
+
for key in ["abstract", "introduction", "methodology"]:
|
| 187 |
if record[key]:
|
| 188 |
+
filtered_sentences = [
|
| 189 |
+
sentence.strip() + "."
|
| 190 |
+
for sentence in record[key].split(".")
|
| 191 |
+
if entity_name in sentence
|
| 192 |
+
]
|
| 193 |
+
sentences.extend(
|
| 194 |
+
[
|
| 195 |
+
f"{record['hash_id']}: {sentence}"
|
| 196 |
+
for sentence in filtered_sentences
|
| 197 |
+
]
|
| 198 |
+
)
|
| 199 |
|
| 200 |
return sentences
|
| 201 |
|
|
|
|
| 204 |
MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
|
| 205 |
"""
|
| 206 |
with self.driver.session() as session:
|
| 207 |
+
result = session.execute_read(
|
| 208 |
+
lambda tx: tx.run(query, year=year, venue_name=venue_name).data()
|
| 209 |
+
)
|
| 210 |
if result:
|
| 211 |
+
return [record["n"] for record in result]
|
| 212 |
else:
|
| 213 |
return []
|
| 214 |
|
|
|
|
| 242 |
RETURN p
|
| 243 |
"""
|
| 244 |
with self.driver.session() as session:
|
| 245 |
+
result = session.execute_write(
|
| 246 |
+
lambda tx: tx.run(
|
| 247 |
+
query,
|
| 248 |
+
hash_id=paper["hash_id"],
|
| 249 |
+
venue_name=paper["venue_name"],
|
| 250 |
+
year=paper["year"],
|
| 251 |
+
title=paper["title"],
|
| 252 |
+
pdf_url=paper["pdf_url"],
|
| 253 |
+
abstract=paper["abstract"],
|
| 254 |
+
introduction=paper["introduction"],
|
| 255 |
+
reference=paper["reference"],
|
| 256 |
+
summary=paper["summary"],
|
| 257 |
+
motivation=paper["motivation"],
|
| 258 |
+
contribution=paper["contribution"],
|
| 259 |
+
methodology=paper["methodology"],
|
| 260 |
+
ground_truth=paper["ground_truth"],
|
| 261 |
+
reference_filter=paper["reference_filter"],
|
| 262 |
+
conclusions=paper["conclusions"],
|
| 263 |
+
).data()
|
| 264 |
+
)
|
| 265 |
|
| 266 |
def check_entity_node_count(self, hash_id: int):
|
| 267 |
query_check_count = """
|
|
|
|
| 270 |
"""
|
| 271 |
with self.driver.session() as session:
|
| 272 |
# Check the number of related entities
|
| 273 |
+
result = session.execute_read(
|
| 274 |
+
lambda tx: tx.run(query_check_count, hash_id=hash_id).data()
|
| 275 |
+
)
|
| 276 |
if result[0]["entity_count"] > 3:
|
| 277 |
return False
|
| 278 |
return True
|
|
|
|
| 287 |
"""
|
| 288 |
with self.driver.session() as session:
|
| 289 |
for entity_name in entities:
|
| 290 |
+
result = session.execute_write(
|
| 291 |
+
lambda tx: tx.run(
|
| 292 |
+
query, entity_name=entity_name, hash_id=hash_id
|
| 293 |
+
).data()
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
def add_paper_citation(self, paper: dict):
|
| 297 |
query = """
|
| 298 |
MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
|
| 299 |
"""
|
| 300 |
with self.driver.session() as session:
|
| 301 |
+
result = session.execute_write(
|
| 302 |
+
lambda tx: tx.run(
|
| 303 |
+
query,
|
| 304 |
+
hash_id=paper["hash_id"],
|
| 305 |
+
cite_id_list=paper["cite_id_list"],
|
| 306 |
+
entities=paper["entities"],
|
| 307 |
+
all_cite_id_list=paper["all_cite_id_list"],
|
| 308 |
+
).data()
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
def add_paper_abstract_embedding(
|
| 312 |
+
self, embedding_model, hash_id=None, batch_size=512
|
| 313 |
+
):
|
| 314 |
if hash_id is not None:
|
| 315 |
query = """
|
| 316 |
MATCH (p:Paper {hash_id: $hash_id})
|
|
|
|
| 318 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
| 319 |
"""
|
| 320 |
with self.driver.session() as session:
|
| 321 |
+
results = session.execute_write(
|
| 322 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
| 323 |
+
)
|
| 324 |
+
contexts = [result["title"] + result["context"] for result in results]
|
| 325 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 326 |
+
context_embeddings = embedding_model.encode(
|
| 327 |
+
contexts, convert_to_tensor=True, device=self.device
|
| 328 |
+
)
|
| 329 |
query = """
|
| 330 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
| 331 |
+
ON CREATE SET p.abstract_embedding = $embedding
|
| 332 |
+
ON MATCH SET p.abstract_embedding = $embedding
|
| 333 |
+
"""
|
| 334 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
| 335 |
+
embedding = (
|
| 336 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 337 |
+
)
|
| 338 |
+
with self.driver.session() as session:
|
| 339 |
+
results = session.execute_write(
|
| 340 |
+
lambda tx: tx.run(
|
| 341 |
+
query, hash_id=hash_id, embedding=embedding
|
| 342 |
+
).data()
|
| 343 |
+
)
|
| 344 |
+
return
|
| 345 |
+
offset = 0
|
| 346 |
+
while True:
|
| 347 |
+
query = f"""
|
| 348 |
MATCH (p:Paper)
|
| 349 |
WHERE p.abstract IS NOT NULL
|
| 350 |
RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
|
| 351 |
+
SKIP $offset LIMIT $batch_size
|
| 352 |
"""
|
| 353 |
with self.driver.session() as session:
|
| 354 |
+
results = session.execute_write(
|
| 355 |
+
lambda tx: tx.run(
|
| 356 |
+
query, offset=offset, batch_size=batch_size
|
| 357 |
+
).data()
|
| 358 |
+
)
|
| 359 |
+
if not results:
|
| 360 |
+
break
|
| 361 |
+
contexts = [result["title"] + result["context"] for result in results]
|
| 362 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 363 |
+
context_embeddings = embedding_model.encode(
|
| 364 |
+
contexts,
|
| 365 |
+
batch_size=batch_size,
|
| 366 |
+
convert_to_tensor=True,
|
| 367 |
+
device=self.device,
|
| 368 |
+
)
|
| 369 |
+
write_query = """
|
| 370 |
+
UNWIND $data AS row
|
| 371 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
| 372 |
+
ON CREATE SET p.abstract_embedding = row.embedding
|
| 373 |
+
ON MATCH SET p.abstract_embedding = row.embedding
|
| 374 |
+
"""
|
| 375 |
+
data_to_write = []
|
| 376 |
+
for idx, hash_id in enumerate(paper_ids):
|
| 377 |
+
embedding = (
|
| 378 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 379 |
+
)
|
| 380 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
| 381 |
with self.driver.session() as session:
|
| 382 |
+
session.execute_write(
|
| 383 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
| 384 |
+
)
|
| 385 |
+
offset += batch_size
|
| 386 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
| 387 |
|
| 388 |
+
def add_paper_bg_embedding(self, embedding_model, hash_id=None, batch_size=512):
|
| 389 |
if hash_id is not None:
|
| 390 |
query = """
|
| 391 |
MATCH (p:Paper {hash_id: $hash_id})
|
| 392 |
WHERE p.motivation IS NOT NULL
|
| 393 |
+
RETURN p.motivation AS context, p.hash_id AS hash_id, p.title AS title
|
| 394 |
"""
|
| 395 |
with self.driver.session() as session:
|
| 396 |
+
results = session.execute_write(
|
| 397 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
| 398 |
+
)
|
| 399 |
+
contexts = [result["context"] for result in results]
|
| 400 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 401 |
+
context_embeddings = embedding_model.encode(
|
| 402 |
+
contexts, convert_to_tensor=True, device=self.device
|
| 403 |
+
)
|
| 404 |
query = """
|
| 405 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
| 406 |
+
ON CREATE SET p.motivation_embedding = $embedding
|
| 407 |
+
ON MATCH SET p.motivation_embedding = $embedding
|
| 408 |
+
"""
|
| 409 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
| 410 |
+
embedding = (
|
| 411 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 412 |
+
)
|
| 413 |
+
with self.driver.session() as session:
|
| 414 |
+
results = session.execute_write(
|
| 415 |
+
lambda tx: tx.run(
|
| 416 |
+
query, hash_id=hash_id, embedding=embedding
|
| 417 |
+
).data()
|
| 418 |
+
)
|
| 419 |
+
return
|
| 420 |
+
offset = 0
|
| 421 |
+
while True:
|
| 422 |
+
query = f"""
|
| 423 |
MATCH (p:Paper)
|
| 424 |
WHERE p.motivation IS NOT NULL
|
| 425 |
+
RETURN p.motivation AS context, p.hash_id AS hash_id, p.title AS title
|
| 426 |
+
SKIP $offset LIMIT $batch_size
|
| 427 |
"""
|
| 428 |
with self.driver.session() as session:
|
| 429 |
+
results = session.execute_write(
|
| 430 |
+
lambda tx: tx.run(
|
| 431 |
+
query, offset=offset, batch_size=batch_size
|
| 432 |
+
).data()
|
| 433 |
+
)
|
| 434 |
+
if not results:
|
| 435 |
+
break
|
| 436 |
+
contexts = [result["title"] + result["context"] for result in results]
|
| 437 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 438 |
+
context_embeddings = embedding_model.encode(
|
| 439 |
+
contexts,
|
| 440 |
+
batch_size=batch_size,
|
| 441 |
+
convert_to_tensor=True,
|
| 442 |
+
device=self.device,
|
| 443 |
+
)
|
| 444 |
+
write_query = """
|
| 445 |
+
UNWIND $data AS row
|
| 446 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
| 447 |
+
ON CREATE SET p.motivation_embedding = row.embedding
|
| 448 |
+
ON MATCH SET p.motivation_embedding = row.embedding
|
| 449 |
+
"""
|
| 450 |
+
data_to_write = []
|
| 451 |
+
for idx, hash_id in enumerate(paper_ids):
|
| 452 |
+
embedding = (
|
| 453 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 454 |
+
)
|
| 455 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
| 456 |
with self.driver.session() as session:
|
| 457 |
+
session.execute_write(
|
| 458 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
| 459 |
+
)
|
| 460 |
+
offset += batch_size
|
| 461 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
| 462 |
+
|
| 463 |
+
def add_paper_contribution_embedding(
|
| 464 |
+
self, embedding_model, hash_id=None, batch_size=512
|
| 465 |
+
):
|
| 466 |
if hash_id is not None:
|
| 467 |
query = """
|
| 468 |
MATCH (p:Paper {hash_id: $hash_id})
|
| 469 |
WHERE p.contribution IS NOT NULL
|
| 470 |
+
RETURN p.contribution AS context, p.hash_id AS hash_id, p.title AS title
|
| 471 |
"""
|
| 472 |
with self.driver.session() as session:
|
| 473 |
+
results = session.execute_write(
|
| 474 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
| 475 |
+
)
|
| 476 |
+
contexts = [result["context"] for result in results]
|
| 477 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 478 |
+
context_embeddings = embedding_model.encode(
|
| 479 |
+
contexts, convert_to_tensor=True, device=self.device
|
| 480 |
+
)
|
| 481 |
query = """
|
| 482 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
| 483 |
+
ON CREATE SET p.contribution_embedding = $embedding
|
| 484 |
+
ON MATCH SET p.contribution_embedding = $embedding
|
| 485 |
+
"""
|
| 486 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
| 487 |
+
embedding = (
|
| 488 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 489 |
+
)
|
| 490 |
+
with self.driver.session() as session:
|
| 491 |
+
results = session.execute_write(
|
| 492 |
+
lambda tx: tx.run(
|
| 493 |
+
query, hash_id=hash_id, embedding=embedding
|
| 494 |
+
).data()
|
| 495 |
+
)
|
| 496 |
+
return
|
| 497 |
+
offset = 0
|
| 498 |
+
while True:
|
| 499 |
+
query = f"""
|
| 500 |
MATCH (p:Paper)
|
| 501 |
WHERE p.contribution IS NOT NULL
|
| 502 |
+
RETURN p.contribution AS context, p.hash_id AS hash_id, p.title AS title
|
| 503 |
+
SKIP $offset LIMIT $batch_size
|
| 504 |
"""
|
| 505 |
with self.driver.session() as session:
|
| 506 |
+
results = session.execute_write(
|
| 507 |
+
lambda tx: tx.run(
|
| 508 |
+
query, offset=offset, batch_size=batch_size
|
| 509 |
+
).data()
|
| 510 |
+
)
|
| 511 |
+
if not results:
|
| 512 |
+
break
|
| 513 |
+
contexts = [result["context"] for result in results]
|
| 514 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 515 |
+
context_embeddings = embedding_model.encode(
|
| 516 |
+
contexts,
|
| 517 |
+
batch_size=batch_size,
|
| 518 |
+
convert_to_tensor=True,
|
| 519 |
+
device=self.device,
|
| 520 |
+
)
|
| 521 |
+
write_query = """
|
| 522 |
+
UNWIND $data AS row
|
| 523 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
| 524 |
+
ON CREATE SET p.contribution_embedding = row.embedding
|
| 525 |
+
ON MATCH SET p.contribution_embedding = row.embedding
|
| 526 |
+
"""
|
| 527 |
+
data_to_write = []
|
| 528 |
+
for idx, hash_id in enumerate(paper_ids):
|
| 529 |
+
embedding = (
|
| 530 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 531 |
+
)
|
| 532 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
| 533 |
with self.driver.session() as session:
|
| 534 |
+
session.execute_write(
|
| 535 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
| 536 |
+
)
|
| 537 |
+
offset += batch_size
|
| 538 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
| 539 |
+
|
| 540 |
+
def add_paper_summary_embedding(
|
| 541 |
+
self, embedding_model, hash_id=None, batch_size=512
|
| 542 |
+
):
|
| 543 |
if hash_id is not None:
|
| 544 |
query = """
|
| 545 |
MATCH (p:Paper {hash_id: $hash_id})
|
| 546 |
WHERE p.summary IS NOT NULL
|
| 547 |
+
RETURN p.summary AS context, p.hash_id AS hash_id, p.title AS title
|
| 548 |
"""
|
| 549 |
with self.driver.session() as session:
|
| 550 |
+
results = session.execute_write(
|
| 551 |
+
lambda tx: tx.run(query, hash_id=hash_id).data()
|
| 552 |
+
)
|
| 553 |
+
contexts = [result["context"] for result in results]
|
| 554 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 555 |
+
context_embeddings = embedding_model.encode(
|
| 556 |
+
contexts, convert_to_tensor=True, device=self.device
|
| 557 |
+
)
|
| 558 |
query = """
|
| 559 |
+
MERGE (p:Paper {hash_id: $hash_id})
|
| 560 |
+
ON CREATE SET p.summary_embedding = $embedding
|
| 561 |
+
ON MATCH SET p.summary_embedding = $embedding
|
| 562 |
+
"""
|
| 563 |
+
for idx, hash_id in tqdm(enumerate(paper_ids)):
|
| 564 |
+
embedding = (
|
| 565 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 566 |
+
)
|
| 567 |
+
with self.driver.session() as session:
|
| 568 |
+
results = session.execute_write(
|
| 569 |
+
lambda tx: tx.run(
|
| 570 |
+
query, hash_id=hash_id, embedding=embedding
|
| 571 |
+
).data()
|
| 572 |
+
)
|
| 573 |
+
return
|
| 574 |
+
offset = 0
|
| 575 |
+
while True:
|
| 576 |
+
query = f"""
|
| 577 |
MATCH (p:Paper)
|
| 578 |
WHERE p.summary IS NOT NULL
|
| 579 |
+
RETURN p.summary AS context, p.hash_id AS hash_id, p.title AS title
|
| 580 |
+
SKIP $offset LIMIT $batch_size
|
| 581 |
"""
|
| 582 |
with self.driver.session() as session:
|
| 583 |
+
results = session.execute_write(
|
| 584 |
+
lambda tx: tx.run(
|
| 585 |
+
query, offset=offset, batch_size=batch_size
|
| 586 |
+
).data()
|
| 587 |
+
)
|
| 588 |
+
if not results:
|
| 589 |
+
break
|
| 590 |
+
contexts = [result["context"] for result in results]
|
| 591 |
+
paper_ids = [result["hash_id"] for result in results]
|
| 592 |
+
context_embeddings = embedding_model.encode(
|
| 593 |
+
contexts,
|
| 594 |
+
batch_size=batch_size,
|
| 595 |
+
convert_to_tensor=True,
|
| 596 |
+
device=self.device,
|
| 597 |
+
)
|
| 598 |
+
write_query = """
|
| 599 |
+
UNWIND $data AS row
|
| 600 |
+
MERGE (p:Paper {hash_id: row.hash_id})
|
| 601 |
+
ON CREATE SET p.summary_embedding = row.embedding
|
| 602 |
+
ON MATCH SET p.summary_embedding = row.embedding
|
| 603 |
+
"""
|
| 604 |
+
data_to_write = []
|
| 605 |
+
for idx, hash_id in enumerate(paper_ids):
|
| 606 |
+
embedding = (
|
| 607 |
+
context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
|
| 608 |
+
)
|
| 609 |
+
data_to_write.append({"hash_id": hash_id, "embedding": embedding})
|
| 610 |
with self.driver.session() as session:
|
| 611 |
+
session.execute_write(
|
| 612 |
+
lambda tx: tx.run(write_query, data=data_to_write)
|
| 613 |
+
)
|
| 614 |
+
offset += batch_size
|
| 615 |
+
logger.info(f"== Processed batch starting at offset {offset} ==")
|
| 616 |
+
|
| 617 |
def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
|
| 618 |
query = f"""
|
| 619 |
MATCH (paper:Paper)
|
|
|
|
| 624 |
ORDER BY score DESC LIMIT {k}
|
| 625 |
"""
|
| 626 |
with self.driver.session() as session:
|
| 627 |
+
results = session.execute_read(
|
| 628 |
+
lambda tx: tx.run(query, embedding=embedding).data()
|
| 629 |
+
)
|
| 630 |
+
related_paper = []
|
| 631 |
for result in results:
|
| 632 |
related_paper.append(result["paper"]["hash_id"])
|
| 633 |
return related_paper
|
|
|
|
| 649 |
"""
|
| 650 |
with self.driver.session() as session:
|
| 651 |
session.execute_write(lambda tx: tx.run(query).data())
|
| 652 |
+
|
| 653 |
def filter_paper_id_list(self, paper_id_list, year="2024"):
|
| 654 |
if not paper_id_list:
|
| 655 |
return []
|
|
|
|
| 661 |
RETURN p.hash_id AS hash_id
|
| 662 |
"""
|
| 663 |
with self.driver.session() as session:
|
| 664 |
+
result = session.execute_read(
|
| 665 |
+
lambda tx: tx.run(query, paper_id_list=paper_id_list, year=year).data()
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
existing_paper_ids = [record["hash_id"] for record in result]
|
| 669 |
existing_paper_ids = list(set(existing_paper_ids))
|
| 670 |
return existing_paper_ids
|
| 671 |
+
|
| 672 |
def check_index_exists(self):
|
| 673 |
query = "SHOW INDEXES"
|
| 674 |
with self.driver.session() as session:
|
|
|
|
| 685 |
"""
|
| 686 |
with self.driver.session() as session:
|
| 687 |
session.execute_write(lambda tx: tx.run(query).data())
|
| 688 |
+
|
| 689 |
def get_entity_related_paper_num(self, entity_name):
|
| 690 |
query = """
|
| 691 |
MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
|
|
|
|
| 693 |
RETURN PaperCount
|
| 694 |
"""
|
| 695 |
with self.driver.session() as session:
|
| 696 |
+
result = session.execute_read(
|
| 697 |
+
lambda tx: tx.run(query, entity_name=entity_name).data()
|
| 698 |
+
)
|
| 699 |
+
paper_num = result[0]["PaperCount"]
|
| 700 |
return paper_num
|
| 701 |
|
| 702 |
+
def get_entities_related_paper_num(self, entity_names):
    """Count, for each given entity name, how many papers it is RELATED_TO.

    Batched replacement for per-entity counting: a single Cypher round trip
    handles the whole list via UNWIND.

    Args:
        entity_names: list of entity name strings to look up.

    Returns:
        dict mapping entity name -> paper count. A name that matches no
        Entity node (or one with no RELATED_TO edges) yields no row in the
        Cypher result and is therefore absent from the returned dict.
    """
    query = """
        UNWIND $entity_names AS entity_name
        MATCH (e:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)
        WITH entity_name, COUNT(p) AS PaperCount
        RETURN entity_name, PaperCount
        """

    with self.driver.session() as session:
        records = session.execute_read(
            lambda tx: tx.run(query, entity_names=entity_names).data()
        )
        # Fold the result rows into an entity-name -> paper-count lookup.
        return {row["entity_name"]: row["PaperCount"] for row in records}
|
| 719 |
+
|
| 720 |
def get_entity_text(self):
|
| 721 |
query = """
|
| 722 |
MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
|
|
|
|
| 726 |
"""
|
| 727 |
with self.driver.session() as session:
|
| 728 |
result = session.execute_read(lambda tx: tx.run(query).data())
|
| 729 |
+
text_list = [record["entity_text"] for record in result]
|
| 730 |
return text_list
|
| 731 |
+
|
| 732 |
def get_entity_combinations(self, venue_name, year):
|
| 733 |
+
def process_paper_relationships(
|
| 734 |
+
session, entity_name_1, entity_name_2, abstract
|
| 735 |
+
):
|
| 736 |
if entity_name_2 < entity_name_1:
|
| 737 |
entity_name_1, entity_name_2 = entity_name_2, entity_name_1
|
| 738 |
query = """
|
|
|
|
| 742 |
ON CREATE SET r.strength = 1
|
| 743 |
ON MATCH SET r.strength = r.strength + 1
|
| 744 |
"""
|
| 745 |
+
sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", abstract)
|
| 746 |
for sentence in sentences:
|
| 747 |
sentence = sentence.lower()
|
| 748 |
if entity_name_1 in sentence and entity_name_2 in sentence:
|
| 749 |
# 如果两个实体在同一句话中出现过,则创建或更新 CONNECT 关系
|
| 750 |
session.execute_write(
|
| 751 |
+
lambda tx: tx.run(
|
| 752 |
+
query,
|
| 753 |
+
entity_name_1=entity_name_1,
|
| 754 |
+
entity_name_2=entity_name_2,
|
| 755 |
+
).data()
|
| 756 |
)
|
| 757 |
# logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
|
| 758 |
break # 如果找到一次出现就可以退出循环
|
|
|
|
| 766 |
RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
|
| 767 |
"""
|
| 768 |
with self.driver.session() as session:
|
| 769 |
+
result = session.execute_read(
|
| 770 |
+
lambda tx: tx.run(query, venue_name=venue_name, year=year).data()
|
| 771 |
+
)
|
| 772 |
for record in tqdm(result):
|
| 773 |
paper_id = record["hash_id"]
|
| 774 |
+
entity_name_1 = record["entity_name_1"]
|
| 775 |
+
entity_name_2 = record["entity_name_2"]
|
| 776 |
abstract = self.get_paper_attribute(paper_id, "abstract")
|
| 777 |
+
process_paper_relationships(
|
| 778 |
+
session, entity_name_1, entity_name_2, abstract
|
| 779 |
+
)
|
| 780 |
|
| 781 |
def build_citemap(self):
|
| 782 |
citemap = defaultdict(set)
|
|
|
|
| 787 |
with self.driver.session() as session:
|
| 788 |
results = session.execute_read(lambda tx: tx.run(query).data())
|
| 789 |
for result in results:
|
| 790 |
+
hash_id = result["hash_id"]
|
| 791 |
+
cite_id_list = result["cite_id_list"]
|
| 792 |
if cite_id_list:
|
| 793 |
for cited_id in cite_id_list:
|
| 794 |
citemap[hash_id].add(cited_id)
|
|
|
|
| 801 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
| 802 |
graph = Graph(URI, auth=AUTH)
|
| 803 |
# 创建一个字典来保存数据
|
| 804 |
+
# 定义批次大小
|
| 805 |
data = {"nodes": [], "relationships": []}
|
| 806 |
+
# 计算数据的总数(例如查询节点总数)
|
| 807 |
+
total_papers_query = "MATCH (e:Entity)-[:RELATED_TO]->(p:Paper) RETURN COUNT(DISTINCT p) AS count"
|
| 808 |
+
total_papers = graph.run(total_papers_query).evaluate()
|
| 809 |
+
print(f"total paper: {total_papers}")
|
| 810 |
+
query = f"""
|
| 811 |
MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
|
|
|
|
| 812 |
RETURN p, e, r
|
| 813 |
"""
|
| 814 |
+
"""
|
| 815 |
results = graph.run(query)
|
| 816 |
# 处理查询结果
|
| 817 |
for record in tqdm(results):
|
|
|
|
| 819 |
entity_node = record["e"]
|
| 820 |
relationship = record["r"]
|
| 821 |
# 将节点数据加入字典
|
| 822 |
+
data["nodes"].append(
|
| 823 |
+
{
|
| 824 |
+
"id": paper_node.identity,
|
| 825 |
+
"label": "Paper",
|
| 826 |
+
"properties": dict(paper_node),
|
| 827 |
+
}
|
| 828 |
+
)
|
| 829 |
+
data["nodes"].append(
|
| 830 |
+
{
|
| 831 |
+
"id": entity_node.identity,
|
| 832 |
+
"label": "Entity",
|
| 833 |
+
"properties": dict(entity_node),
|
| 834 |
+
}
|
| 835 |
+
)
|
| 836 |
# 将关系数据加入字典
|
| 837 |
+
data["relationships"].append(
|
| 838 |
+
{
|
| 839 |
+
"start_node": entity_node.identity,
|
| 840 |
+
"end_node": paper_node.identity,
|
| 841 |
+
"type": "RELATED_TO",
|
| 842 |
+
"properties": dict(relationship),
|
| 843 |
+
}
|
| 844 |
+
)
|
| 845 |
+
"""
|
| 846 |
query = """
|
| 847 |
MATCH (p:Paper)
|
| 848 |
WHERE p.venue_name='acl' and p.year='2024'
|
| 849 |
RETURN p
|
| 850 |
"""
|
|
|
|
| 851 |
results = graph.run(query)
|
| 852 |
for record in tqdm(results):
|
| 853 |
paper_node = record["p"]
|
| 854 |
# 将节点数据加入字典
|
| 855 |
+
data["nodes"].append(
|
| 856 |
+
{
|
| 857 |
+
"id": paper_node.identity,
|
| 858 |
+
"label": "Paper",
|
| 859 |
+
"properties": dict(paper_node),
|
| 860 |
+
}
|
| 861 |
+
)
|
| 862 |
# 去除重复节点
|
| 863 |
# data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
|
| 864 |
unique_nodes = []
|
|
|
|
| 871 |
unique_nodes.append(node)
|
| 872 |
data["nodes"] = unique_nodes
|
| 873 |
# 将数据保存为 JSON 文件
|
| 874 |
+
with open(
|
| 875 |
+
"./assets/data/scipip_neo4j_clean_backup.json", "w", encoding="utf-8"
|
| 876 |
+
) as f:
|
| 877 |
json.dump(data, f, ensure_ascii=False, indent=4)
|
| 878 |
+
|
| 879 |
def neo4j_import_data(self):
|
| 880 |
# clear_database() # 清空数据库,谨慎执行
|
| 881 |
URI = os.environ["NEO4J_URL"]
|
|
|
|
| 884 |
AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
|
| 885 |
graph = Graph(URI, auth=AUTH)
|
| 886 |
# 从 JSON 文件中读取数据
|
| 887 |
+
with open(
|
| 888 |
+
"./assets/data/scipip_neo4j_clean_backup.json", "r", encoding="utf-8"
|
| 889 |
+
) as f:
|
| 890 |
data = json.load(f)
|
| 891 |
# 创建节点
|
| 892 |
nodes = {}
|
src/utils/paper_retriever.py
CHANGED
|
@@ -59,6 +59,7 @@ class CoCite:
|
|
| 59 |
|
| 60 |
def __init__(self) -> None:
|
| 61 |
if not self._initialized:
|
|
|
|
| 62 |
self.paper_client = PaperClient()
|
| 63 |
citemap = self.paper_client.build_citemap()
|
| 64 |
self.comap = defaultdict(lambda: defaultdict(int))
|
|
@@ -101,20 +102,16 @@ class Retriever(object):
|
|
| 101 |
|
| 102 |
def retrieve_entities_by_enties(self, entities):
|
| 103 |
# TODO: KG
|
| 104 |
-
expand_entities =
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
relation_name=self.config.RETRIEVE.relation_name,
|
| 111 |
-
)
|
| 112 |
expand_entities = list(set(entities + expand_entities))
|
| 113 |
-
entity_paper_num_dict =
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
self.paper_client.get_entity_related_paper_num(entity)
|
| 117 |
-
)
|
| 118 |
new_entities = []
|
| 119 |
entity_paper_num_dict = {
|
| 120 |
k: v for k, v in entity_paper_num_dict.items() if v != 0
|
|
@@ -142,11 +139,7 @@ class Retriever(object):
|
|
| 142 |
Return:
|
| 143 |
related_paper: list(dict)
|
| 144 |
"""
|
| 145 |
-
related_paper =
|
| 146 |
-
for paper_id in paper_id_list:
|
| 147 |
-
paper = {"hash_id": paper_id}
|
| 148 |
-
self.paper_client.update_paper_from_client(paper)
|
| 149 |
-
related_paper.append(paper)
|
| 150 |
return related_paper
|
| 151 |
|
| 152 |
def calculate_similarity(self, entities, related_entities_list, use_weight=False):
|
|
@@ -333,7 +326,6 @@ class Retriever(object):
|
|
| 333 |
similarity_threshold = self.config.RETRIEVE.similarity_threshold
|
| 334 |
similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
|
| 335 |
target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
|
| 336 |
-
# target_labels = list(range(0, len(target_paper_id_list)))
|
| 337 |
target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
|
| 338 |
logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
|
| 339 |
logger.debug(
|
|
@@ -672,8 +664,7 @@ class SNKGRetriever(Retriever):
|
|
| 672 |
)
|
| 673 |
related_paper = set()
|
| 674 |
related_paper.update(sn_paper_id_list)
|
| 675 |
-
|
| 676 |
-
sn_entities += self.paper_client.find_entities_by_paper(paper_id)
|
| 677 |
logger.debug("SN entities for retriever: {}".format(sn_entities))
|
| 678 |
entities = list(set(entities + sn_entities))
|
| 679 |
new_entities = self.retrieve_entities_by_enties(entities)
|
|
|
|
| 59 |
|
| 60 |
def __init__(self) -> None:
|
| 61 |
if not self._initialized:
|
| 62 |
+
logger.debug("init co-cite map begin...")
|
| 63 |
self.paper_client = PaperClient()
|
| 64 |
citemap = self.paper_client.build_citemap()
|
| 65 |
self.comap = defaultdict(lambda: defaultdict(int))
|
|
|
|
| 102 |
|
| 103 |
def retrieve_entities_by_enties(self, entities):
|
| 104 |
# TODO: KG
|
| 105 |
+
expand_entities = self.paper_client.find_related_entities_by_entity_list(
|
| 106 |
+
entities,
|
| 107 |
+
n=self.config.RETRIEVE.kg_jump_num,
|
| 108 |
+
k=self.config.RETRIEVE.kg_cover_num,
|
| 109 |
+
relation_name=self.config.RETRIEVE.relation_name,
|
| 110 |
+
)
|
|
|
|
|
|
|
| 111 |
expand_entities = list(set(entities + expand_entities))
|
| 112 |
+
entity_paper_num_dict = self.paper_client.get_entities_related_paper_num(
|
| 113 |
+
expand_entities
|
| 114 |
+
)
|
|
|
|
|
|
|
| 115 |
new_entities = []
|
| 116 |
entity_paper_num_dict = {
|
| 117 |
k: v for k, v in entity_paper_num_dict.items() if v != 0
|
|
|
|
| 139 |
Return:
|
| 140 |
related_paper: list(dict)
|
| 141 |
"""
|
| 142 |
+
related_paper = self.paper_client.update_papers_from_client(paper_id_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
return related_paper
|
| 144 |
|
| 145 |
def calculate_similarity(self, entities, related_entities_list, use_weight=False):
|
|
|
|
| 326 |
similarity_threshold = self.config.RETRIEVE.similarity_threshold
|
| 327 |
similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
|
| 328 |
target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
|
|
|
|
| 329 |
target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
|
| 330 |
logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
|
| 331 |
logger.debug(
|
|
|
|
| 664 |
)
|
| 665 |
related_paper = set()
|
| 666 |
related_paper.update(sn_paper_id_list)
|
| 667 |
+
sn_entities += self.paper_client.find_entities_by_paper_list(sn_paper_id_list)
|
|
|
|
| 668 |
logger.debug("SN entities for retriever: {}".format(sn_entities))
|
| 669 |
entities = list(set(entities + sn_entities))
|
| 670 |
new_entities = self.retrieve_entities_by_enties(entities)
|