| import json | |
| from utils.paper_retriever import RetrieverFactory | |
| from utils.llms_api import APIHelper | |
| from utils.header import ConfigReader | |
| from utils.hash import check_env, check_embedding | |
| from generator import IdeaGenerator | |
| import functools | |
class Backend(object):
    """Wire the idea-generation pipeline steps to callback-style entry points.

    Holds the loaded configuration, an ``APIHelper``, a paper retriever, an
    ``IdeaGenerator`` and a list of example inputs. Each ``*_callback`` method
    performs one step of the background -> entities -> literature -> ideas
    flow by delegating to those components.
    """

    # Defaults used when the constructor is called without arguments.
    DEFAULT_CONFIG_PATH = "./configs/datasets.yaml"
    DEFAULT_EXAMPLE_PATH = "./assets/data/example.json"
    DEFAULT_USE_INSPIRATION = True
    DEFAULT_BRAINSTORM_MODE = "mode_c"

    def __init__(
        self,
        *,
        config_path=DEFAULT_CONFIG_PATH,
        example_path=DEFAULT_EXAMPLE_PATH,
        use_inspiration=DEFAULT_USE_INSPIRATION,
        brainstorm_mode=DEFAULT_BRAINSTORM_MODE,
    ) -> None:
        """Load config, run environment checks and build pipeline components.

        All arguments are keyword-only. Their defaults are the values that
        were previously hard-coded, so ``Backend()`` behaves as before.

        Args:
            config_path: Config file passed to ``ConfigReader.load``.
            example_path: JSON file read by ``load_examples`` into
                ``self.examples``.
            use_inspiration: Stored as ``self.use_inspiration``. It is not
                read by any method of this class.
            brainstorm_mode: Stored as ``self.brainstorm_mode``. It is not
                read by any method of this class.
        """
        self.config = ConfigReader.load(config_path)
        check_env()
        check_embedding(self.config.DEFAULT.embedding)
        retriever_name = self.config.RETRIEVE.retriever_name
        self.api_helper = APIHelper(self.config)
        # NOTE: despite its name, this attribute holds the object returned by
        # create_retriever() (the thing whose .retrieve() is called below),
        # not the factory itself. The name is kept for existing callers.
        self.retriever_factory = (
            RetrieverFactory.get_retriever_factory().create_retriever(
                retriever_name, self.config
            )
        )
        # NOTE(review): the second argument is None here. It is presumably
        # the paper list, which literature2initial_ideas_callback later
        # assigns via ``paper_list``. Confirm against IdeaGenerator.
        self.idea_generator = IdeaGenerator(self.config, None)
        self.use_inspiration = use_inspiration
        self.brainstorm_mode = brainstorm_mode
        self.examples = self.load_examples(example_path)

    def load_examples(self, path):
        """Return the parsed JSON content of ``path``.

        This is best-effort. If the file is missing, is not valid UTF-8, or
        is not valid JSON, a message is printed and ``[]`` is returned. The
        file is decoded as UTF-8 rather than with the platform default
        encoding.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, UnicodeDecodeError, json.JSONDecodeError) as e:
            print(f"Error loading examples from {path}: {e}")
            return []

    def background2entities_callback(self, background):
        """Return the entity list that ``api_helper`` extracts from ``background``."""
        return self.api_helper.generate_entity_list(background)

    def background2expandedbackground_callback(self, background, entities):
        """Expand ``background`` using ``entities`` as a keyword string.

        The entities are converted to ``str`` and joined with ``", "``. An
        empty ``entities`` produces an empty keyword string rather than an
        error.
        """
        keywords_str = ", ".join(str(entity) for entity in entities)
        return self.api_helper.expand_background(background, keywords_str)

    def background2brainstorm_callback(self, expanded_background):
        """Return the brainstorm that ``api_helper`` generates for ``expanded_background``."""
        return self.api_helper.generate_brainstorm(expanded_background)

    def brainstorm2entities_callback(self, brainstorm, entities):
        """Merge ``entities`` with entities extracted from ``brainstorm``.

        Duplicates are removed. The original ``entities`` come first in
        their given order, followed by new entities from the brainstorm.
        """
        # NOTE(review): 10 is passed straight through to generate_entity_list;
        # presumably a cap on the number of entities. Confirm.
        entities_bs = self.api_helper.generate_entity_list(brainstorm, 10)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is deterministic. A set union is not, because str hashing is
        # randomized per process.
        return list(dict.fromkeys([*entities, *entities_bs]))

    def upload_json_callback(self, input):
        """Read a JSON file and return ``[its "background" value, raw file text]``.

        Raises ``json.JSONDecodeError`` if the file is not valid JSON, and
        ``KeyError`` if the parsed object has no ``"background"`` key.
        """
        # ``input`` shadows the builtin but is the public parameter name.
        with open(input, "r", encoding="utf-8") as json_file:
            contents = json_file.read()
        json_contents = json.loads(contents)
        return [json_contents["background"], contents]

    def entities2literature_callback(self, expanded_background, entities):
        """Retrieve related papers and format a one-line citation for each.

        Returns:
            A pair ``(citations, papers)``. ``citations`` holds strings of
            the form ``"<title>. <VENUE> <year>."``. ``papers`` is the raw
            ``result["related_paper"]`` list returned by the retriever.
        """
        result = self.retriever_factory.retrieve(
            expanded_background, entities, need_evaluate=False, target_paper_id_list=[]
        )
        papers = result["related_paper"]
        citations = [
            f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.' for p in papers
        ]
        return citations, papers

    def literature2initial_ideas_callback(
        self, expanded_background, brainstorms, retrieved_literature
    ):
        """Generate ideas from the background, brainstorm and retrieved papers.

        Sets ``paper_list`` and ``brainstorm`` on the shared ``idea_generator``
        before calling ``generate_ins_bs``. Returns its last two results
        ``(idea_filtered, final_ideas)``.
        """
        # NOTE(review): this mutates state on a generator shared by every call
        # on this Backend. Overlapping concurrent calls could interfere.
        self.idea_generator.paper_list = retrieved_literature
        self.idea_generator.brainstorm = brainstorms
        # generate_ins_bs yields a 6-tuple: (?, ?, inspirations, initial_ideas,
        # idea_filtered, final_ideas). The strict unpack keeps the size check.
        _, _, _, _, idea_filtered, final_ideas = (
            self.idea_generator.generate_ins_bs(expanded_background)
        )
        return idea_filtered, final_ideas

    def initial2final_callback(self, initial_ideas, final_ideas):
        """Return ``final_ideas`` unchanged; ``initial_ideas`` is ignored."""
        return final_ideas

    def get_demo_i(self, i):
        """Return the ``"background"`` text of example ``i``.

        Returns a placeholder message if ``i`` is out of range, or if that
        example has no ``"background"`` key.
        """
        if 0 <= i < len(self.examples):
            return self.examples[i].get("background", "Background not found.")
        else:
            return "Example not found. Please select a valid index."