Spaces:
Runtime error
Runtime error
| import json | |
| import time | |
| from searcher import Result,SementicSearcher | |
| from LLM import openai_llm | |
| from prompts import * | |
| from utils import extract | |
def get_llm(model="gpt4o-0513"):
    """Build and return an ``openai_llm`` client for the given model name."""
    client = openai_llm(model)
    return client
def get_llms():
    """Return the ``(main_llm, cheap_llm)`` pair.

    The main model is "gpt4o-0513"; the cheap model is "gpt-4o-mini".
    """
    # Tuple elements are evaluated left to right, so the main client is built first.
    return get_llm("gpt4o-0513"), get_llm("gpt-4o-mini")
def judge_idea(i, j, idea0, idea1, topic, llm):
    """Ask ``llm`` to compare two ideas on ``topic`` and return the extracted verdicts.

    Args:
        i, j: Caller-supplied identifiers, echoed back unchanged.
        idea0, idea1: The two ideas forwarded to ``get_judge_idea_all_prompt``.
        topic: Research topic forwarded to the prompt.
        llm: Object exposing ``response(messages)``.

    Returns:
        ``(i, j, novelty, relevance, significance, clarity, feasibility, effectiveness)``,
        each criterion being whatever ``extract`` pulls from the matching tag.
    """
    criteria = ("novelty", "relevance", "significance", "clarity", "feasibility", "effectiveness")
    reply = llm.response([{"role": "user", "content": get_judge_idea_all_prompt(idea0, idea1, topic)}])
    verdicts = [extract(reply, name) for name in criteria]
    return (i, j, *verdicts)
class DeepResearchAgent:
    """Grow a chain of related papers for a topic and turn it into a new research idea.

    Workflow (entry point ``generate_idea_with_chain``):

    1. Ask the main LLM for search queries and find one unread seed paper.
    2. Extend the chain "forward" with ``reader.search_related_paper(need_reference=False)``
       (appended after the seed) and "backward" through the references the LLM extracts
       from each paper (prepended before it), keeping papers the LLM judges relevant.
    3. Ask the LLM for the trend along the chain, clean up the entities, propose future
       directions, draft an idea and refine it into the final idea.

    Intermediate artefacts are recorded on the instance (``search_qeuries``,
    ``deep_research_chains``, ``deep_ideas``).
    """

    def __init__(self, llm=None, cheap_llm=None, publicationData=None, ban_paper=None, **kwargs) -> None:
        """Create the agent.

        Args:
            llm: Main chat model; must expose ``response(messages)``.
            cheap_llm: Model used for per-paper extraction; must expose
                ``response(messages, max_tokens=...)``.
            publicationData: Forwarded as ``publicationDate`` to reader searches.
            ban_paper: Forwarded to ``SementicSearcher``; ``None`` means an empty list.
            **kwargs: Optional ``improve_cnt`` (default 1), ``max_chain_length`` (5),
                ``min_chain_length`` (3) and ``max_chain_numbers`` (10).
        """
        # Fresh list per agent: the former ``ban_paper=[]`` default was one list object
        # shared by every instance that relied on the default.
        self.reader = SementicSearcher(ban_paper=[] if ban_paper is None else ban_paper)
        self.begin_time = time.time()
        self.llm = llm
        self.cheap_llm = cheap_llm
        # Titles added by the seed search; also used to skip already-read references.
        self.read_papers = set()
        self.paper_storage = []
        self.paper_info_for_refine_experiment = []
        self.search_qeuries = []  # (sic) name kept for existing readers of this attribute
        self.deep_research_chains = []
        self.deep_ideas = []
        self.check_novel_results = []
        self.score_results = []
        self.topic = None
        self.publicationData = publicationData
        self.improve_cnt = kwargs.get("improve_cnt", 1)
        self.max_chain_length = kwargs.get("max_chain_length", 5)
        self.min_chain_length = kwargs.get("min_chain_length", 3)
        self.max_chain_numbers = kwargs.get("max_chain_numbers", 10)

    def wrap_messages(self, prompt):
        """Wrap ``prompt`` as a single-turn user message list."""
        return [{"role": "user", "content": prompt}]

    def get_openai_response(self, messages):
        """Send ``messages`` to the main LLM and return its reply."""
        return self.llm.response(messages)

    def get_cheap_openai_response(self, messages):
        """Send ``messages`` to the cheap LLM (``max_tokens=16000``) and return its reply."""
        return self.cheap_llm.response(messages, max_tokens=16000)

    @staticmethod
    def _parse_json_list(raw):
        """Parse ``raw`` as a JSON list; return ``[]`` if it is missing, malformed or not a list."""
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):  # None input / invalid JSON (JSONDecodeError)
            return []
        return parsed if isinstance(parsed, list) else []

    @staticmethod
    def _format_idea_chains(idea_chain, idea_papers):
        """Render the chain as numbered ``"{i}.Paper:{title} idea:{idea}"`` entries."""
        return "".join(
            f"{i}.Paper:{title} idea:{idea}\n \n"
            for i, (idea, title) in enumerate(zip(idea_chain, idea_papers))
        )

    def get_search_query(self, topic=None, query=None):
        """Ask the main LLM to expand ``topic``/``query`` into a list of search queries.

        On success the parsed list is recorded in ``self.search_qeuries`` and returned.
        Otherwise the result is ``[query]``, or ``[topic]`` when no query was given, so
        callers never end up searching for ``None``.
        """
        prompt = get_deep_search_query_prompt(topic, query)
        response = self.get_openai_response(self.wrap_messages(prompt))
        raw_queries = extract(response, "queries")
        try:
            search_query = json.loads(raw_queries)
        except (TypeError, ValueError):
            search_query = None
        if not isinstance(search_query, list):
            return [query if query is not None else topic]
        self.search_qeuries.append({"query": query, "search_query": search_query})
        return search_query

    def _search_seed_paper(self, query, topic):
        """Find one unread paper for ``query``, asking the LLM to rewrite the query on each miss.

        Gives up after 10 attempts. Returns a list holding the paper, or ``[]``.
        """
        failed_query = []
        for _ in range(10):
            paper = self.reader.search(query, 1, paper_list=self.read_papers, llm=self.llm,
                                       rerank_query=f"{topic}", publicationDate=self.publicationData)
            if paper and len(paper) > 0 and paper[0]:
                self.read_papers.add(paper[0].title)
                return [paper[0]]
            failed_query.append(query)
            prompt = get_deep_rewrite_query_prompt(failed_query, topic)
            new_query = extract(self.get_openai_response(self.wrap_messages(prompt)), "query")
            print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
            query = new_query
        return []

    def generate_idea_with_chain(self, topic):
        """Generate a research idea for ``topic`` from a chain of related papers.

        Returns ``(idea, experiments, entities, idea_chains, idea, trend, future, human, years)``,
        or nine ``None`` values when no seed paper (or its article) can be found.
        """
        self.topic = topic
        print(f"begin to generate search query for {topic}")
        search_query = self.get_search_query(topic=topic)
        papers = []
        for query in search_query:
            papers.extend(self._search_seed_paper(query, topic))
            if len(papers) >= self.max_chain_numbers:
                break
        if not papers:
            print(f"failed to generate idea {topic}")
            return (None,) * 9
        result = self.deep_research_paper_with_chain(papers[0])
        if result is None:
            # The seed paper had no readable article; unpacking None used to raise TypeError.
            print(f"failed to generate idea {topic}")
            return (None,) * 9
        # Unpack in the exact order deep_research_paper_with_chain returns. The old code
        # unpacked it as (idea, idea_chain, experiment, entities, trend, ...), which swapped
        # trend/experiments/entities in the returned tuple.
        idea, idea_chain, trend, experiments, entities, future, human, years = result
        print("successfully generated idea")
        return idea, experiments, entities, idea_chain, idea, trend, future, human, years

    def _analyze_paper_content(self, paper_content):
        """Run the deep-reference prompt on ``paper_content`` with the cheap LLM.

        Returns the raw extracted ``(idea, experiment, entities, references)`` tags.
        """
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = self.get_cheap_openai_response(self.wrap_messages(prompt))
        entities = extract(response, "entities")
        idea = extract(response, "idea")
        experiment = extract(response, "experiment")
        references = extract(response, "references")
        return idea, experiment, entities, references

    def get_paper_idea_experiment_references_info(self, paper):
        """Extract idea/experiment/entities/references for a search result.

        Returns ``(idea, experiment, entities, references, title)``, or ``None`` when the
        result carries no article.
        """
        article = paper.article
        if not article:
            return None
        paper_content = self.reader.read_paper_content(article)
        idea, experiment, entities, references = self._analyze_paper_content(paper_content)
        return idea, experiment, entities, references, paper.title

    def get_article_idea_experiment_references_info(self, article):
        """Extract ``(idea, experiment, entities, references)`` from an article read with references.

        ``references`` is the raw extracted text; parse it with ``_parse_json_list``.
        """
        paper_content = self.reader.read_paper_content_with_ref(article)
        return self._analyze_paper_content(paper_content)

    def _judge_relevant(self, title, abstract):
        """Ask the main LLM whether a candidate paper fits ``self.topic``; return the raw <relevant> tag."""
        prompt = get_deep_judge_relevant_prompt(title, abstract, self.topic)
        return extract(self.get_openai_response(self.wrap_messages(prompt)), "relevant")

    def _find_reference_paper(self, references, current_title, current_abstract, chain_length, idea_papers):
        """Choose the next paper to prepend to the chain, consuming ``references`` from the front.

        Each unread reference is searched. Reference hits that are not relevant are skipped.
        If a hit cannot be used (falsy, already read, or relevant but without an article),
        the method falls back to ``reader.search_related_paper(need_citation=False)`` for
        the current paper. While the chain is shorter than ``min_chain_length``, relevance
        is not required. Returns a search result that has an article, or ``None``.
        """
        short_chain = chain_length < self.min_chain_length
        while references:
            reference = references.pop(0)
            if reference in self.read_papers:
                continue
            hits = self.reader.search(reference, 3, llm=self.llm, publicationDate=self.publicationData,
                                      paper_list=idea_papers)
            if not hits:  # also tolerates a None result
                continue
            candidate = hits[0]
            if not candidate or candidate.title in self.read_papers:
                break
            # Judge the candidate itself (previously the current paper was re-judged).
            if short_chain or self._judge_relevant(candidate.title, candidate.abstract) != "0":
                if candidate.article:
                    return candidate
                break
            print(f"the paper {candidate.title} is not relevant")

        rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
        related = self.reader.search_related_paper(current_title, need_citation=False, rerank_query=rerank_query,
                                                   llm=self.llm, paper_list=idea_papers)
        if not related:
            return None
        if short_chain:
            return related if related.article else None
        if related.title in self.read_papers:
            return None
        if self._judge_relevant(related.title, related.abstract) == "1" and related.article:
            return related
        return None

    def _generate_idea_from_chain(self, idea_chain, idea_papers, experiments, total_entities, years, latest_entities):
        """Summarize the chain's trend, clean the entities and draft the final idea.

        Appends the chain to ``self.deep_research_chains`` and the idea to ``self.deep_ideas``.
        Returns ``(idea, idea_chains, trend, experiments, cleaned_entities, future, human, years)``.
        """
        idea_chains = self._format_idea_chains(idea_chain, idea_papers)
        # NOTE(review): the trend prompt receives only the most recently extracted entities,
        # not the whole ``total_entities`` list; kept as-is, confirm this is intended.
        prompt = get_deep_trend_idea_chains_prompt(idea_chains, latest_entities, self.topic)
        trend = extract(self.get_openai_response(self.wrap_messages(prompt)), "trend")
        self.deep_research_chains.append({"idea_chains": idea_chains, "trend": trend, "topic": self.topic,
                                          "ideas": idea_chain, "experiments": experiments,
                                          "entities": total_entities, "years": years})

        prompt = f"""The current research topic is: {self.topic}. Please help me summarize and refine the following entities by merging, simplifying, or deleting them : {total_entities}
Please output strictly in the following format:
<entities> {{cleaned entities}}</entities>
"""
        cleaned_entities = extract(self.get_openai_response(self.wrap_messages(prompt)), "entities")

        bad_case = []
        prompt = get_deep_generate_future_direciton_prompt(idea_chain, trend, self.topic, cleaned_entities)
        response = self.get_openai_response(self.wrap_messages(prompt))
        future = extract(response, "future")
        human = extract(response, "human")

        prompt = get_deep_generate_idea_prompt(idea_chains, trend, self.topic, cleaned_entities, future, bad_case)
        response = self.get_openai_response(self.wrap_messages(prompt))
        draft_idea = {"motivation": extract(response, "motivation"),
                      "novelty": extract(response, "novelty"),
                      "method": extract(response, "method")}

        prompt = get_deep_final_idea_prompt(idea_chains, trend, draft_idea, self.topic)
        final_idea = extract(self.get_openai_response(self.wrap_messages(prompt)), "final_idea")
        self.deep_ideas.append(final_idea)
        return final_idea, idea_chains, trend, experiments, cleaned_entities, future, human, years

    def deep_research_paper_with_chain(self, paper: "Result"):
        """Build an idea chain around ``paper`` and generate a new idea from it.

        Returns ``(idea, idea_chains, trend, experiments, entities, future, human, years)``,
        where ``idea_chains`` is the formatted chain text and ``entities`` is the LLM-cleaned
        entity summary. Returns ``None`` when ``paper`` has no article.
        """
        print(f"begin to deep research paper {paper.title}")
        article = paper.article
        if not article:
            print(f"failed to deep research paper {paper.title}")
            return None

        idea, experiment, entities, references = self.get_article_idea_experiment_references_info(article)
        references = self._parse_json_list(references)
        # Parallel lists ordered from the first prepended (reference) paper to the last appended one.
        idea_chain = [idea]
        idea_papers = [paper.title]
        experiments = [experiment]
        total_entities = [entities]
        years = [paper.year]

        # Forward: append papers found via search_related_paper(need_reference=False).
        current_title, current_abstract = paper.title, paper.abstract
        while len(idea_chain) < self.max_chain_length:
            rerank_query = f"{self.topic} {current_title} {current_abstract}"
            citation_paper = self.reader.search_related_paper(current_title, need_reference=False,
                                                              rerank_query=rerank_query, llm=self.llm,
                                                              paper_list=idea_papers)
            if not citation_paper:
                print(f"failed to find citation paper for {current_title}")
                break
            # Judge the candidate, not the current (already accepted) paper.
            if self._judge_relevant(citation_paper.title, citation_paper.abstract) == "0":
                print(f"the paper {citation_paper.title} is not relevant")
                break
            result = self.get_paper_idea_experiment_references_info(citation_paper)
            if not result:
                break
            idea, experiment, entities, _, _ = result
            idea_chain.append(idea)
            experiments.append(experiment)
            total_entities.append(entities)
            idea_papers.append(citation_paper.title)
            years.append(citation_paper.year)
            current_title, current_abstract = citation_paper.title, citation_paper.abstract

        # Backward: prepend papers taken from the extracted reference lists.
        current_title, current_abstract = paper.title, paper.abstract
        while len(idea_chain) < self.max_chain_length and references:
            print(f"The references find:{references}")
            cite_paper = self._find_reference_paper(references, current_title, current_abstract,
                                                    len(idea_chain), idea_papers)
            if cite_paper is None:
                # _find_reference_paper consumed at least one reference, so this loop terminates.
                print(f"failed to find citation paper for {current_title}")
                continue
            print("find the citation paper, begin to deep research")
            idea, experiment, entities, references = self.get_article_idea_experiment_references_info(cite_paper.article)
            references = self._parse_json_list(references)
            current_title, current_abstract = cite_paper.title, cite_paper.abstract
            years.insert(0, cite_paper.year)
            idea_chain.insert(0, idea)
            idea_papers.insert(0, cite_paper.title)
            experiments.insert(0, experiment)
            total_entities.insert(0, entities)
            # Stop at a highly cited paper once the chain is long enough.
            if len(idea_chain) >= self.min_chain_length and cite_paper.citations_conut > 1000:
                break

        print("successfully generate idea chain")
        result = self._generate_idea_from_chain(idea_chain, idea_papers, experiments, total_entities, years, entities)
        print(f"successfully deep research paper {paper.title}")
        return result
if __name__ == "__main__":
    # Script entry point: constructs a SementicSearcher with its default arguments.
    reader = SementicSearcher()