Spaces: Runtime error
Commit e3a17c0 · delete async
Parent(s): a81bf47

Files changed:
- .gitattributes copy  +0 -35
- LLM.py  +43 -8
- agents.py  +34 -73
- app.py  +1 -1
- main.py  +4 -9
- searcher/sementic_search.py  +49 -121
.gitattributes copy
DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
LLM.py
CHANGED
@@ -123,7 +123,13 @@ class openai_llm(base_llm):
                 input=text,
                 timeout= 180
             )
-
+            embbeding = embbeding.data
+            if len(embbeding) == 0:
+                return None
+            elif len(embbeding) == 1:
+                return embbeding[0].embedding
+            else:
+                return [e.embedding for e in embbeding]
         except Exception as e:
             print(f"get embbeding failed: {e}")
             print(e)
@@ -147,7 +153,13 @@ class openai_llm(base_llm):
                 input=text,
                 timeout= 180
             )
-
+            embbeding = embbeding.data
+            if len(embbeding) == 0:
+                return None
+            elif len(embbeding) == 1:
+                return embbeding[0].embedding
+            else:
+                return [e.embedding for e in embbeding]
         except Exception as e:
             await asyncio.sleep(0.1)
             print(f"get embbeding failed: {e}")
@@ -178,9 +190,32 @@ class openai_llm(base_llm):
 
 
 if __name__ == "__main__":
-
-
-
-
-
-
+    import os
+    import yaml
+
+    def cal_cosine_similarity_matric(matric1, matric2):
+        if isinstance(matric1, list):
+            matric1 = np.array(matric1)
+        if isinstance(matric2, list):
+            matric2 = np.array(matric2)
+        if len(matric1.shape) == 1:
+            matric1 = matric1.reshape(1, -1)
+        if len(matric2.shape) == 1:
+            matric2 = matric2.reshape(1, -1)
+        dot_product = np.dot(matric1, matric2.T)
+        norm1 = np.linalg.norm(matric1, axis=1)
+        norm2 = np.linalg.norm(matric2, axis=1)
+
+        cos_sim = dot_product / np.outer(norm1, norm2)
+        scores = cos_sim.flatten()
+        # return a list
+        return scores.tolist()
+
+    texts = ["What is the capital of France?","What is the capital of Spain?", "What is the capital of Italy?", "What is the capital of Germany?"]
+    text = "What is the capital of France?"
+    llm = openai_llm()
+    embbedings = llm.get_embbeding(texts)
+    embbeding = llm.get_embbeding(text)
+
+    scores = cal_cosine_similarity_matric(embbedings, embbeding)
+    print(scores)
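For reference, a minimal usage sketch (not part of the commit) of the reworked synchronous get_embbeding: a single string now yields one embedding vector, a list of strings yields a list of vectors, and an empty result yields None. It assumes an openai_llm instance configured with valid API credentials, as in the __main__ block above; the example strings are placeholders.

import numpy as np
from LLM import openai_llm

llm = openai_llm()
single = llm.get_embbeding("graph neural networks")                    # one string  -> one vector (list of floats)
batch = llm.get_embbeding(["graph neural networks", "transformers"])   # list input  -> list of vectors

if single is not None and batch:
    batch = np.array(batch)
    single = np.array(single)
    # cosine similarity of each batch entry against the single query vector
    sims = batch @ single / (np.linalg.norm(batch, axis=1) * np.linalg.norm(single))
    print(sims.tolist())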
agents.py
CHANGED
@@ -1,7 +1,5 @@
 import json
 import time
-import asyncio
-import os
 from searcher import Result,SementicSearcher
 from LLM import openai_llm
 from prompts import *
@@ -17,10 +15,10 @@ def get_llms():
     cheap_llm = get_llm("gpt-4o-mini")
     return main_llm,cheap_llm
 
-
+def judge_idea(i,j,idea0,idea1,topic,llm):
     prompt = get_judge_idea_all_prompt(idea0,idea1,topic)
     messages = [{"role":"user","content":prompt}]
-    response =
+    response = llm.response(messages)
     novelty = extract(response,"novelty")
     relevance = extract(response,"relevance")
     significance = extract(response,"significance")
@@ -55,16 +53,16 @@ class DeepResearchAgent:
     def wrap_messages(self,prompt):
         return [{"role":"user","content":prompt}]
 
-
-        return
+    def get_openai_response(self,messages):
+        return self.llm.response(messages)
 
-
-        return
+    def get_cheap_openai_response(self,messages):
+        return self.cheap_llm.response(messages,max_tokens = 16000)
 
-
+    def get_search_query(self,topic = None,query=None):
         prompt = get_deep_search_query_prompt(topic,query)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_openai_response(messages)
         search_query = extract(response,"queries")
         try:
             search_query = json.loads(search_query)
@@ -73,17 +71,17 @@ class DeepResearchAgent:
             search_query = [query]
         return search_query
 
-
+    def generate_idea_with_chain(self,topic):
         self.topic = topic
         print(f"begin to generate search query for {topic}")
-        search_query =
+        search_query = self.get_search_query(topic=topic)
         papers = []
         for query in search_query:
             failed_query = []
             current_papers = []
             cnt = 0
             while len(current_papers) == 0 and cnt < 10:
-                paper =
+                paper = self.reader.search(query,1,paper_list=self.read_papers,llm=self.llm,rerank_query=f"{topic}",publicationDate=self.publicationData)
                 if paper and len(paper) > 0 and paper[0]:
                     self.read_papers.add(paper[0].title)
                     current_papers.append(paper[0])
@@ -91,7 +89,7 @@ class DeepResearchAgent:
                     failed_query.append(query)
                     prompt = get_deep_rewrite_query_prompt(failed_query,topic)
                     messages = self.wrap_messages(prompt)
-                    new_query =
+                    new_query = self.get_openai_response(messages)
                     new_query = extract(new_query,"query")
                     print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
                     query = new_query
@@ -104,67 +102,30 @@ class DeepResearchAgent:
             print(f"failed to generate idea {topic}")
            return None,None,None,None,None,None,None,None,None
 
-
-        results = await asyncio.gather(*tasks)
-        results = [result for result in results if result]
-        if len(results) ==0:
-            print(f"failed to generate idea {topic}")
-            return None,None,None,None,None,None,None,None,None
-
-        ideas,idea_chains,experiments,entities,trends,futures,humans,years = [[result[i] for result in results] for i in range(8)]
-
-        tasks = []
-        for i,idea_1 in enumerate(ideas):
-            for j,idea_2 in enumerate(ideas):
-                if i != j:
-                    tasks.append(judge_idea(i,j,idea_1,idea_2,topic,self.llm))
-        results = await asyncio.gather(*tasks)
-        elo_scores = [0 for _ in range(len(ideas))]
-        elo_selected = 0
-        def change_winner_to_score(winner,score_1,score_2):
-            try:
-                winner = int(winner)
-            except:
-                return score_1+0.5,score_2+0.5
-            if winner == 0:
-                return score_1+1,score_2
-            if winner == 2:
-                return score_1+0.5,score_2+0.5
-            return score_1,score_2+1
-        for result in results:
-            i,j,novelty,relevance,significance,clarity,feasibility,effectiveness = result
-            for dimension in [novelty,relevance,significance,clarity,feasibility,effectiveness]:
-                elo_scores[i],elo_scores[j] = change_winner_to_score(dimension,elo_scores[i],elo_scores[j])
-            print(f"i:{i},j:{j},novelty:{novelty},relevance:{relevance},significance:{significance},clarity:{clarity},feasibility:{feasibility},effectiveness:{effectiveness}")
-        print(elo_scores)
-        try:
-            elo_selected = elo_scores.index(max(elo_scores))
-        except:
-            elo_selected = 0
+        idea,idea_chain,experiment,entities,trend,future,human,year = self.deep_research_paper_with_chain(papers[0])
 
-        idea,experiment,entities,idea_chain,trend,future,human,year = ideas[elo_selected],experiments[elo_selected],entities[elo_selected],idea_chains[elo_selected],trends[elo_selected],futures[elo_selected],humans[elo_selected],years[elo_selected]
         print(f"successfully generated idea")
-        return idea,experiment,entities,idea_chain,
+        return idea,experiment,entities,idea_chain,idea,trend,future,human,year
 
-
+    def get_paper_idea_experiment_references_info(self,paper):
         article = paper.article
         if not article:
             return None
         paper_content = self.reader.read_paper_content(article)
         prompt = get_deep_reference_prompt(paper_content,self.topic)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_cheap_openai_response(messages)
         entities = extract(response,"entities")
         idea = extract(response,"idea")
         experiment = extract(response,"experiment")
         references = extract(response,"references")
         return idea,experiment,entities,references,paper.title
 
-
+    def get_article_idea_experiment_references_info(self,article):
         paper_content = self.reader.read_paper_content_with_ref(article)
         prompt = get_deep_reference_prompt(paper_content,self.topic)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_cheap_openai_response(messages)
         entities = extract(response,"entities")
         idea = extract(response,"idea")
         experiment = extract(response,"experiment")
@@ -172,7 +133,7 @@ class DeepResearchAgent:
         return idea,experiment,entities,references
 
 
-
+    def deep_research_paper_with_chain(self,paper:Result):
         print(f"begin to deep research paper {paper.title}")
         article = paper.article
         if not article:
@@ -183,7 +144,7 @@ class DeepResearchAgent:
         experiments = []
         total_entities = []
         years = []
-        idea,experiment,entities,references =
+        idea,experiment,entities,references = self.get_article_idea_experiment_references_info(article)
         try:
             references = json.loads(references)
         except:
@@ -200,7 +161,7 @@ class DeepResearchAgent:
         # search before
         while len(idea_chain)<self.max_chain_length:
             rerank_query = f"{self.topic} {current_title} {current_abstract}"
-            citation_paper =
+            citation_paper = self.reader.search_related_paper(current_title,need_reference=False,rerank_query=rerank_query,llm=self.llm,paper_list=idea_papers)
             if not citation_paper:
                 print(f"failed to find citation paper for {current_title}")
                 break
@@ -208,10 +169,10 @@ class DeepResearchAgent:
             abstract = citation_paper.abstract
             prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
             messages = self.wrap_messages(prompt)
-            response =
+            response = self.get_openai_response(messages)
             relevant = extract(response,"relevant")
             if relevant != "0":
-                result =
+                result = self.get_paper_idea_experiment_references_info(citation_paper)
                 if not result:
                     break
                 idea,experiment,entities,_,_ = result
@@ -238,13 +199,13 @@ class DeepResearchAgent:
                 references.pop(0)
                 if reference in self.read_papers:
                     continue
-                search_paper =
+                search_paper = self.reader.search(reference,3,llm=self.llm,publicationDate=self.publicationData,paper_list= idea_papers)
                 if len(search_paper) > 0:
                     s_p = search_paper[0]
                     if s_p and s_p.title not in self.read_papers:
                         prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
                         messages = self.wrap_messages(prompt)
-                        response =
+                        response = self.get_openai_response(messages)
                         relevant = extract(response,"relevant")
                         if relevant != "0" or len(idea_chain) < self.min_chain_length:
                             article = s_p.article
@@ -257,7 +218,7 @@ class DeepResearchAgent:
 
             if not article:
                 rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
-                search_paper =
+                search_paper = self.reader.search_related_paper(current_title,need_citation=False,rerank_query = rerank_query,llm=self.llm,paper_list=idea_papers)
                 if not search_paper:
                     print(f"failed to find citation paper for {current_title}")
                     continue
@@ -273,10 +234,10 @@ class DeepResearchAgent:
                 if s_p and s_p.title not in self.read_papers:
                     prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
                     messages = self.wrap_messages(prompt)
-                    response =
+                    response = self.get_openai_response(messages)
                     relevant = extract(response,"relevant")
                     if relevant == "1" or len(idea_chain) < self.min_chain_length:
-                        article =
+                        article = s_p.article
                     if not article:
                         continue
                 else:
@@ -290,7 +251,7 @@ class DeepResearchAgent:
             paper_content = self.reader.read_paper_content_with_ref(article)
             prompt = get_deep_reference_prompt(paper_content,self.topic)
             messages = self.wrap_messages(prompt)
-            response =
+            response = self.get_cheap_openai_response(messages)
             idea = extract(response,"idea")
             references = extract(response,"references")
             experiment = extract(response,"experiment")
@@ -317,7 +278,7 @@ class DeepResearchAgent:
 
         prompt = get_deep_trend_idea_chains_prompt(idea_chains,entities,self.topic)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_openai_response(messages)
         trend = extract(response,"trend")
 
         self.deep_research_chains.append({"idea_chains":idea_chains,"trend":trend,"topic":self.topic,"ideas":idea_chain,"experiments":experiments,"entities":total_entities,"years":years})
@@ -326,26 +287,26 @@ class DeepResearchAgent:
 <entities> {{cleaned entities}}</entities>
 """
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_openai_response(messages)
         total_entities = extract(response,"entities")
         bad_case = []
         prompt = get_deep_generate_future_direciton_prompt(idea_chain,trend,self.topic,total_entities)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_openai_response(messages)
         future = extract(response,"future")
         human = extract(response,"human")
 
 
         prompt = get_deep_generate_idea_prompt(idea_chains,trend,self.topic,total_entities,future,bad_case)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_openai_response(messages)
         method = extract(response,"method")
         novelty = extract(response,"novelty")
         motivation = extract(response,"motivation")
         idea = {"motivation":motivation,"novelty":novelty,"method":method}
         prompt = get_deep_final_idea_prompt(idea_chains,trend,idea,self.topic)
         messages = self.wrap_messages(prompt)
-        response =
+        response = self.get_openai_response(messages)
         final_idea = extract(response,"final_idea")
 
         idea = final_idea
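For orientation, a minimal sketch (not part of the commit) of the now fully synchronous entry path that app.py and main.py use; the constructor arguments mirror app.py, the unpacking order mirrors app.py's call site, and the topic string is only a placeholder.

from agents import DeepResearchAgent, get_llms

main_llm, cheap_llm = get_llms()
agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1,
                          max_chain_length=5, min_chain_length=3, max_chain_numbers=1)

# generate_idea_with_chain now blocks until the idea chain is built;
# no asyncio.run or await is needed anywhere in the call path.
idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year = \
    agent.generate_idea_with_chain("example topic")
print(idea)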
app.py
CHANGED
@@ -332,7 +332,7 @@ def form_post(topic: str = Form(...)):
     main_llm, cheap_llm = get_llms()
     deep_research_agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1, max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
     print(f"begin to generate idea of topic {topic}")
-    idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year =
+    idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year = deep_research_agent.generate_idea_with_chain(topic)
     idea_md = markdown.markdown(idea)
     # update the daily reply count
     reply_count += 1
main.py
CHANGED
@@ -1,8 +1,9 @@
-from agents import DeepResearchAgent,
+from agents import DeepResearchAgent,get_llms
 import asyncio
 import json
 import argparse
 
+
 if __name__ == '__main__':
 
     argparser = argparse.ArgumentParser()
@@ -21,18 +22,12 @@ if __name__ == '__main__':
     topic = args.topic
     anchor_paper_path = args.anchor_paper_path
 
-
-    review_agent = ReviewAgent(save_file=args.save_file,llm=main_llm,cheap_llm=cheap_llm)
     deep_research_agent = DeepResearchAgent(llm=main_llm,cheap_llm=cheap_llm,**vars(args))
 
     print(f"begin to generate idea and experiment of topic {topic}")
-    idea,related_experiments,entities,idea_chain,ideas,trend,future,human,year=
-    experiment = asyncio.run(deep_research_agent.generate_experiment(idea,related_experiments,entities))
-
-    for i in range(args.improve_cnt):
-        experiment = asyncio.run(deep_research_agent.improve_experiment(review_agent,idea,experiment,entities))
+    idea,related_experiments,entities,idea_chain,ideas,trend,future,human,year= deep_research_agent.generate_idea_with_chain(topic)
 
     print(f"succeed to generate idea and experiment of topic {topic}")
-    res = {"idea":idea,"
+    res = {"idea":idea,"related_experiments":related_experiments,"entities":entities,"idea_chain":idea_chain,"ideas":ideas,"trend":trend,"future":future,"year":year,"human":human}
     with open("result.json","w") as f:
         json.dump(res,f)
searcher/sementic_search.py
CHANGED
@@ -7,7 +7,7 @@ import time
 import aiohttp
 import asyncio
 import numpy as np
-
+import random
 
 def get_content_between_a_b(start_tag, end_tag, text):
     extracted_text = ""
@@ -31,29 +31,6 @@ def extract(text, type):
         return text
     else:
         return ""
-
-
-async def fetch(url):
-    await asyncio.sleep(1)  # async sleep instead of time.sleep
-    try:
-        timeout = aiohttp.ClientTimeout(total=120)
-        connector = aiohttp.TCPConnector(limit_per_host=10)  # use a connection pool
-        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
-            async with session.get(url) as response:
-                if response.status == 200:
-                    content = await response.read()  # Read the response content as bytes
-                    return content
-                else:
-                    print(f"Failed to fetch the URL: {url} with status code: {response.status}")
-                    return None
-    except aiohttp.ClientError as e:  # more specific exception handling
-        print(f"An error occurred while fetching the URL: {url}")
-        print(e)
-        return None
-    except Exception as e:
-        print(f"An unexpected error occurred while fetching the URL: {url}")
-        print(e)
-        return None
 
 def download(url):
     try:
@@ -103,7 +80,7 @@ class SementicSearcher:
     def __init__(self, ban_paper = []) -> None:
         self.ban_paper = ban_paper
 
-
+    def search_papers(self, query, limit=5, offset=0, fields=["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citations.title","citations.abstract","citations.isOpenAccess","citations.openAccessPdf","citations.citationCount","citationCount","citations.year"],
                       publicationDate=None, minCitationCount=0, year=None,
                       publicationTypes=None, fieldsOfStudy=None):
         url = 'https://api.semanticscholar.org/graph/v1/paper/search'
@@ -124,7 +101,6 @@ class SementicSearcher:
         # Load the API key from the configuration file
         api_key = os.environ.get('SEMENTIC_SEARCH_API_KEY',None)
         headers = {'x-api-key': api_key} if api_key else None
-        await asyncio.sleep(0.5)
         try:
             filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
             response = requests.get(url, params=filtered_query_params, headers=headers)
@@ -135,7 +111,7 @@ class SementicSearcher:
             elif response.status_code == 429:
                 time.sleep(1)
                 print(f"Request failed with status code {response.status_code}: begin to retry")
-                return
+                return self.search_papers(query, limit, offset, fields, publicationDate, minCitationCount, year, publicationTypes, fieldsOfStudy)
             else:
                 print(f"Request failed with status code {response.status_code}: {response.text}")
                 return None
@@ -145,6 +121,23 @@ class SementicSearcher:
 
     def cal_cosine_similarity(self, vec1, vec2):
         return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+    def cal_cosine_similarity_matric(self,matric1, matric2):
+        if isinstance(matric1, list):
+            matric1 = np.array(matric1)
+        if isinstance(matric2, list):
+            matric2 = np.array(matric2)
+        if len(matric1.shape) == 1:
+            matric1 = matric1.reshape(1, -1)
+        if len(matric2.shape) == 1:
+            matric2 = matric2.reshape(1, -1)
+        dot_product = np.dot(matric1, matric2.T)
+        norm1 = np.linalg.norm(matric1, axis=1)
+        norm2 = np.linalg.norm(matric2, axis=1)
+
+        cos_sim = dot_product / np.outer(norm1, norm2)
+        scores = cos_sim.flatten()
+        return scores.tolist()
 
     def read_arxiv_from_path(self, pdf_path):
         def is_pdf(binary_data):
@@ -163,97 +156,41 @@ class SementicSearcher:
             return None
         return article_dict
 
-
+    def get_paper_embbeding_and_score(self,query_embedding, paper,llm):
         paper_content = f"""
 Title: {paper['title']}
 Abstract: {paper['abstract']}
 """
-        paper_embbeding =
+        paper_embbeding = llm.get_embbeding(paper_content)
         paper_embbeding = np.array(paper_embbeding)
         score = self.cal_cosine_similarity(query_embedding,paper_embbeding)
         return [paper,score]
 
 
-
+    def rerank_papers(self, query_embedding, paper_list,llm):
+        if len(paper_list) == 0:
+            return []
+        paper_list = [paper for paper in paper_list if paper]
         if len(paper_list) >= 50:
-            paper_list = paper_list
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        url = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}'
-        fields = process_fields(fields)
-        paper_data_query_params = {'fields': fields}
-        try:
-            async with aiohttp.ClientSession() as session:
-                filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
-                headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
-                async with session.get(url, params=filtered_query_params, headers=headers) as response:
-                    if response.status == 200:
-                        response_data = await response.json()
-                        return response_data
-                    else:
-                        await asyncio.sleep(0.01)
-                        print(f"Request failed with status code {response.status}: {await response.text()}")
-                        return None
-        except Exception as e:
-            print(f"Failed to get paper details for paper ID: {paper_id}")
-            return None
-
-    async def batch_retrieve_papers_async(self, paper_ids, fields = semantic_fields):
-        url = 'https://api.semanticscholar.org/graph/v1/paper/batch'
-        paper_data_query_params = {'fields': process_fields(fields)}
-        paper_ids_json = {"ids": paper_ids}
-        try:
-            async with aiohttp.ClientSession() as session:
-                filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
-                headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
-                async with session.post(url, json=paper_ids_json, params=filtered_query_params, headers=headers) as response:
-                    if response.status == 200:
-                        response_data = await response.json()
-                        return response_data
-                    else:
-                        await asyncio.sleep(0.01)
-                        print(f"Request failed with status code {response.status}: {await response.text()}")
-                        return None
-        except Exception as e:
-            print(f"Failed to batch retrieve papers for paper IDs: {paper_ids}")
-            return None
-
-    async def search_paper_from_title_async(self, query,fields = ["title","paperId"]):
-        url = 'https://api.semanticscholar.org/graph/v1/paper/search/match'
-        fields = process_fields(fields)
-        query_params = {'query': query, 'fields': fields}
-        try:
-            async with aiohttp.ClientSession() as session:
-                filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
-                headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
-                async with session.get(url, params=filtered_query_params, headers=headers) as response:
-                    if response.status == 200:
-                        response_data = await response.json()
-                        return response_data
-                    else:
-                        await asyncio.sleep(0.01)
-                        print(f"Request failed with status code {response.status}: {await response.text()}")
-                        return None
-        except Exception as e:
-            await asyncio.sleep(0.01)
-            print(f"Failed to search paper from title: {query}")
-            return None
+            paper_list = random.sample(paper_list,50)
+        paper_contents = []
+        for paper in paper_list:
+            paper_content = f"""
+Title: {paper['title']}
+Abstract: {paper['abstract']}
+"""
+            paper_contents.append(paper_content)
+        paper_contents_embbeding = llm.get_embbeding(paper_contents)
+        paper_contents_embbeding = np.array(paper_contents_embbeding)
+        scores = self.cal_cosine_similarity_matric(query_embedding,paper_contents_embbeding)
+
+        # sort paper_list by score
+        paper_list = sorted(zip(paper_list,scores),key = lambda x: x[1],reverse = True)
+        paper_list = [paper[0] for paper in paper_list]
+        return paper_list
 
 
-
+    def search(self,query,max_results = 5 ,paper_list = None ,rerank_query = None,llm = None,year = None,publicationDate = None,need_download = True,fields = ["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citationCount"]):
         if rerank_query:
             rerank_query_embbeding = llm.get_embbeding(rerank_query)
             rerank_query_embbeding = np.array(rerank_query_embbeding)
@@ -270,7 +207,7 @@ Abstract: {paper['abstract']}
         readed_papers = [paper.title for paper in paper_list]
 
         print(f"Searching for papers related to the query: <{query}>")
-        results =
+        results = self.search_papers(query,limit = 10 * max_results,year=year,publicationDate = publicationDate,fields = fields)
         if not results or "data" not in results:
             return []
 
@@ -293,8 +230,7 @@ Abstract: {paper['abstract']}
         paper_candidates = results
 
         if llm and rerank_query:
-            paper_candidates =
-            paper_candidates = [paper[0] for paper in paper_candidates if paper]
+            paper_candidates = self.rerank_papers(rerank_query_embbeding, paper_candidates,llm)
 
         if need_download:
             for result in paper_candidates:
@@ -326,10 +262,10 @@ Abstract: {paper['abstract']}
                 break
         return final_results
 
-
-        print(f"Searching for the related papers of <{title}
+    def search_related_paper(self,title,need_citation = True,need_reference = True,rerank_query = None,llm = None,paper_list = []):
+        print(f"Searching for the related papers of <{title}>, need_citation: {need_citation}, need_reference: {need_reference}")
         fileds = ["title","abstract","citations.title","citations.abstract","citations.citationCount","references.title","references.abstract","references.citationCount","citations.isOpenAccess","citations.openAccessPdf","references.isOpenAccess","references.openAccessPdf","citations.year","references.year"]
-        results =
+        results = self.search_papers(title,limit = 3,fields=fileds)
         related_papers = []
         related_papers_title = []
         if not results or "data" not in results:
@@ -367,8 +303,7 @@ Abstract: {paper['abstract']}
         if rerank_query and llm:
             rerank_query_embbeding = llm.get_embbeding(rerank_query)
             rerank_query_embbeding = np.array(rerank_query_embbeding)
-            related_papers =
-            related_papers = [paper[0] for paper in related_papers]
+            related_papers = self.rerank_papers(rerank_query_embbeding, related_papers,llm)
             related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
         else:
             related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
@@ -385,13 +320,6 @@ Abstract: {paper['abstract']}
                 return result
         return None
 
-
-    async def download_pdf_async(self, pdf_link):
-        content = await fetch(pdf_link)
-        if not content:
-            return None
-        else:
-            return content
 
     def download_pdf(self, pdf_link):
         content = download(pdf_link)
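To make the embedding-based rerank above easier to follow, here is a standalone sketch (not part of the commit) of the same cosine-similarity ordering that rerank_papers applies; the rerank helper name and the toy vectors are illustrative only, and the embeddings stand in for what openai_llm.get_embbeding returns.

import numpy as np

def rerank(query_embedding, paper_embeddings, papers):
    q = np.array(query_embedding).reshape(1, -1)
    m = np.array(paper_embeddings)
    # cosine similarity of the query against every paper embedding
    scores = (q @ m.T).flatten() / (np.linalg.norm(q) * np.linalg.norm(m, axis=1))
    # highest similarity first, mirroring rerank_papers' sorted(..., reverse=True)
    order = np.argsort(scores)[::-1]
    return [papers[i] for i in order]

papers = ["paper A", "paper B"]
print(rerank([1.0, 0.0], [[0.9, 0.1], [0.1, 0.9]], papers))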