Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, Form, Request,Response, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse | |
| from jinja2 import Template | |
| import markdown | |
| import time | |
| from datetime import datetime, timedelta | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| from agents import DeepResearchAgent, get_llms | |
| import hashlib | |
| import threading | |
| import logging | |
| from queue import Queue | |
| import json | |
| from collections import Counter | |
# Daily reply quota enforced by the POST handler.
MAX_REPLIES_PER_DAY = 500
# Replies served since the last reset; reset_counter() sets it back to 0.
reply_count = 0
# Recorded once at import time (reset_counter() does not update it).
last_reset_time = datetime.now()

# The FastAPI application object.
app = FastAPI()
# Held around idea generation so only one request runs the agent at a time.
lock = threading.Lock()
# user_id of the request that most recently acquired `lock`; None initially.
current_user = None
# Jinja2 page template, rendered via Template(html_template).render(...) by the
# request handlers. Context keys: idea, error, reply_count, time_taken,
# button_text, loading_text and script (raw JavaScript placed in a <script>).
html_template = """
<!DOCTYPE html>
<html>
<head>
    <title>CoI Agent online demo 😊</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f4f4f9;
            margin: 0;
            padding: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
        }
        .container {
            width: 95%;
            max-width: 1200px;
            background-color: #fff;
            padding: 2rem;
            border-radius: 10px;
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        }
        h1 {
            font-size: 1.5rem;
            margin-bottom: 1.5rem;
            color: #333;
            text-align: center;
        }
        form {
            margin-bottom: 1.5rem;
        }
        .form-group {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 1.5rem;
        }
        .form-group label {
            flex: 0;
            font-size: 1rem; /* 增大字体 */
            color: #333;
            margin-right: 0.5rem;
            background-color: #f0f8ff; /* 气泡背景颜色 */
            padding: 0.5rem 1rem; /* 气泡内边距 */
            border-radius: 10px; /* 气泡圆角 */
            text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.1); /* 艺术字效果 */
            font-family: 'Times new roman', cursive, sans-serif; /* 艺术字体 */
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.1); /* 气泡阴影 */
        }
        .form-group input {
            flex: 4;
            padding: 0.6rem;
            font-size: 1rem;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-left: 1rem;
        }
        .form-group button {
            flex: 0;
            padding: 0.6rem 1rem;
            font-size: 1rem;
            background-color: #F2A582;
            color: #fff;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            transition: background-color 0.3s ease;
            margin-left: 1rem;
        }
        .form-group button:hover {
            background-color: #0056b3;
        }
        .loading,
        .time-box,
        .counter-box,
        .result,
        .error {
            margin-top: 1.5rem;
        }
        .loading {
            font-size: 1.2rem;
            color: #007bff;
            animation: fadeIn 0.5s ease-in-out;
            text-align: center;
            display: none;
        }
        .time-counter-container {
            display: flex;
            justify-content: space-between;
        }
        .time-box,
        .counter-box {
            display: inline-block;
            padding: 0.5rem 1rem;
            background-color: #e9ecef;
            border-radius: 10px;
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);
            font-size: 0.9rem;
            margin: 0.5rem;
            flex: 1;
            text-align: center;
        }
        .result {
            display: flex;
            justify-content: space-between;
            flex-wrap: wrap;
        }
        .result .box {
            flex: 1;
            margin: 0.5rem;
            padding: 0.5rem;
            background-color:#fafafa;
            color:#4a4a4a;
            word-wrap: break-word;
            height: 500px;
            overflow-y: auto;
            font-size: 0.9rem;
            font-family: "Times New Roman", Times, serif;
            line-height: 1.8;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            border-radius: 10px;
            border: 1px solid #ddd;
        }
        .error .box {
            width: 95%;
            padding: 1rem;
            background-color: #f8d7da;
            color: #721c24;
            border-radius: 10px;
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);
            word-wrap: break-word;
            margin-left: 1rem;
        }
        h2 {
            font-size: 1.3rem;
            margin-bottom: 1rem;
            color: #333;
        }
        @keyframes fadeIn {
            from { opacity: 0; }
            to { opacity: 1; }
        }
        .progress-bar-container {
            width: 99%;
            background-color: #e9ecef;
            border-radius: 10px;
            overflow: hidden;
            margin: 0.5rem auto; /* 水平居中 */
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);
        }
        .progress-bar {
            height: 20px;
            background-color: #727372;
            width: 0%;
            transition: width 0.1s ease;
        }
        .example-container {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 1.5rem;
            height: 50px;
            margin-left: 1rem;
        }
        .example-label {
            flex: 0.7;
            font-size: 1rem; /* 修正了 font-size 的语法 */
            color: purple; /* 将字体颜色改为紫色 */
            text-align: center;
            margin-right: 0rem;
            padding: 0.5rem 0.2rem;
            background-color: #f0f8ff;
            border: 1px solid white; /* 添加白色边框 */
            border-radius: 10px;
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);
            font-family: 'Times New Roman', cursive, sans-serif;
            text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.1);
            position: relative; /* 为了放置表格状符号 */
            display: flex;
            align-items: center;
        }
        .example-label::before {
            content: "📋"; /* 表格状符号 */
            margin-right: 0.5rem; /* 符号和文本之间的间距 */
            font-size: 1.2rem;
            color: purple; /* 符号颜色也改成紫色 */
        }
        .example-topics {
            flex: 6;
            display: flex;
            min-width: 80%;
            justify-content: space-around;
        }
        .example-topics button {
            padding: 0.5rem 1rem;
            font-size: 0.9rem;
            background-color: #f0f8ff; /* 橙色 */
            color: #000000;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            margin: 0.3rem;
            transition: background-color 0.3s ease;
            height: 50px;
            min-width: 320px;
            border: 1px solid white; /* 添加白色边框 */
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);
            font-family: 'Times New Roman', cursive, sans-serif;
        }
        .example-topics button:hover {
            background-color: #e0e0e0;
        }
    </style>
    <script>
        let startTime = 0;
        let intervalId = null;
        let progressIntervalId = null;
        let maxTime = 180; // 最大时间180秒
        function showLoading() {
            document.getElementById("loading").style.display = "block";
            document.getElementById("submit-btn").disabled = true;
            startTime = Date.now();
            intervalId = setInterval(updateTime, 100);
            progressIntervalId = setInterval(updateProgressBar, 100);
            // 隐藏错误消息
            const errorBox = document.querySelector(".error");
            if (errorBox) {
                errorBox.style.display = "none";
            }
        }
        function hideLoading() {
            document.getElementById("loading").style.display = "none";
            document.getElementById("submit-btn").disabled = false;
            if (intervalId) {
                clearInterval(intervalId);
                intervalId = null;
            }
            if (progressIntervalId) {
                clearInterval(progressIntervalId);
                progressIntervalId = null;
            }
            updateProgressBar(100); // 立即更新进度条至100%
        }
        function updateTime() {
            const now = Date.now();
            const elapsed = ((now - startTime) / 1000).toFixed(2);
            document.getElementById("time-taken").innerText = `Time Taken: ${elapsed} s`;
        }
        function updateProgressBar(percentage = null) {
            const progressBar = document.getElementById("progress-bar");
            if (percentage !== null) {
                progressBar.style.width = `${percentage}%`;
            } else {
                const now = Date.now();
                const elapsed = (now - startTime) / 1000;
                const progress = Math.min((elapsed / maxTime) * 60, 97);
                progressBar.style.width = `${progress}%`;
            }
        }
        function fillTopic(topic) {
            document.getElementById("topic").value = topic;
        }
    </script>
</head>
<body>
    <div class="container">
        <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-bottom: 1.5rem;">
            <div>
                <h1 >Chain-of-Ideas Agent: Revolutionizing Research Via Novel Idea Development with LLM Agents</h1>
                <h3 style="margin: 0;">If this demo pleases you, please give us a star ⭐ on Github or 💖 on this space.</h3>
                <h5 style="margin: 0;">
                    We only provide a simplified version here, and the number of replies is limited daily. <br>
                    If you want to experience the full version, please go to our GitHub repository
                    <a href="https://github.com/DAMO-NLP-SG/CoI-Agent" style="color: blue;"> https://github.com/DAMO-NLP-SG/CoI-Agent </a> <br>
                    Due to the instability of PDF downloads, the waiting time may be longer, thank you for your patience.
                </h5>
            </div>
        </div>
        <div class="time-counter-container">
            <div id="time-taken" class="time-box">Time Taken: {{ time_taken }} seconds</div>
            <div class="counter-box">Today's Replies: {{ reply_count }}/500 </div>
        </div>
        <div class="progress-bar-container">
            <div id="progress-bar" class="progress-bar"></div>
        </div>
        <div class="result">
            <div class="box">
                <div>{{ idea | safe }}</div>
            </div>
        </div>
        <form action="/" method="post" onsubmit="showLoading()">
            <div class="form-group">
                <input type="text" id="topic" name="topic" placeholder="Enter your topic">
                <input type="hidden" id="user_id" name="user_id">
                <input type="hidden" id="state" name="state">
                <button type="submit" id="submit-btn">{{ button_text }}</button>
            </div>
        </form>
        <div class="example-container">
            <div class="example-label">Example:</div>
            <div class="example-topics">
                <button onclick="fillTopic('Realistic Image Synthesis in Medical Imaging')">Realistic Image Synthesis in Medical Imaging</button>
                <button onclick="fillTopic('Using diffusion to generate road layout')">Using diffusion to generate road layout</button>
                <button onclick="fillTopic('Using LLM-based agent to generate research ideas')">Using LLM-based agent to generate research ideas</button>
            </div>
        </div>
        <div id="loading" class="loading">{{ loading_text }}</div>
        {% if error %}
        <div class="error">
            <div class="box">
                <h2>Error</h2>
                <div>{{ error }}</div>
            </div>
        </div>
        {% endif %}
    </div>
    <script>
        {{ script}}
    </script>
    <script>
        // Connect back to the host that served this page, using wss:// when the
        // page itself was loaded over HTTPS (a hard-coded localhost URL only
        // works when the browser runs on the server machine).
        const wsScheme = window.location.protocol === "https:" ? "wss://" : "ws://";
        const socket = new WebSocket(wsScheme + window.location.host + "/ws");
        socket.addEventListener('open', function (event) {
            const userId = document.getElementById("user_id").value;
            socket.send(JSON.stringify({ action: "connect", user_id: userId }));
        });
        window.addEventListener("beforeunload", function (event) {
            // send() throws while the socket is still connecting; only notify
            // the server when the connection is actually open.
            if (socket.readyState !== WebSocket.OPEN) {
                return;
            }
            const userId = document.getElementById("user_id").value;
            socket.send(JSON.stringify({ action: "disconnect", user_id: userId }));
        });
    </script>
</body>
</html>
"""
def generate_user_id(ip: str) -> str:
    """Derive a stable user ID from a client IP address.

    The ID is the MD5 hex digest (32 lowercase hex characters) of the
    UTF-8-encoded IP string, so the same IP always yields the same ID.
    """
    hasher = hashlib.md5()
    hasher.update(ip.encode("utf-8"))
    return hasher.hexdigest()
def reset_counter():
    """Start a new day: zero the reply counter and drop every queued user ID.

    Entries are cleared directly on the queue's underlying deque, so requests
    that are still waiting for their turn lose their place.
    """
    global reply_count
    reply_count = 0
    pending = queue.queue
    pending.clear()
# Run reset_counter every day at 00:00 (cron trigger, scheduler's default
# timezone). BackgroundScheduler runs jobs in its own thread, started here at
# import time.
scheduler = BackgroundScheduler()
scheduler.add_job(reset_counter, trigger="cron", hour=0, minute=0)
scheduler.start()
def _is_numbered_item(line):
    """Return True if *line* starts, after indentation, with ``<digits>.``."""
    head, dot, _ = line.lstrip().partition('.')
    # isascii() rules out Unicode digits such as "²", which isdigit() accepts.
    return bool(dot) and head.isascii() and head.isdigit()


def fix_markdown(text):
    """Separate numbered list items and join the lines with HTML line breaks.

    Any line whose first non-blank characters are an ASCII integer followed
    by a period ("1.", "12.", "31.", ...) gets an empty line inserted before
    it, unless it is the first line or the line above it is already blank.
    The lines are then joined with "<br>" (not newlines), so every source line
    break becomes an HTML line break and each inserted empty line shows up as
    an extra "<br>".

    Args:
        text: Newline-separated text produced by the idea generator.

    Returns:
        A single string with the lines joined by "<br>".
    """
    lines = text.split('\n')
    result = []
    for i, line in enumerate(lines):
        if _is_numbered_item(line) and i > 0 and lines[i - 1].strip():
            result.append('')
        result.append(line)
    return '<br>'.join(result)
# str.format template for the inline page script: on page load it fills the
# hidden user_id and state form fields. {user_id} and {state} are format
# placeholders; doubled braces ({{ }}) become literal JavaScript braces.
# NOTE(review): values are inserted into JS string literals without any
# escaping, so callers must pass trusted values (e.g. IDs from
# generate_user_id) — confirm every call site does.
script_template = """
function setstate() {{
document.getElementById("user_id").value = "{user_id}";
document.getElementById("state").value = "{state}";
let userId = document.getElementById("user_id").value;
let state = document.getElementById("state").value;
console.log(`1 User ID: ${{userId}}, State: ${{state}}`);
}}
window.onload = setstate;
"""
# FIFO of user IDs waiting for (or holding) the single generation slot; the
# request whose ID is at the head is the one allowed to run. maxsize=0 is
# Queue's default and means unbounded.
queue = Queue(maxsize=0)
async def websocket_endpoint(websocket: WebSocket):
    """Receive page lifecycle messages and drop queued entries of departed users.

    The page sends JSON text frames shaped like
    ``{"action": "connect" | "disconnect", "user_id": ...}``. On
    ``"disconnect"`` every queued entry for that user is removed, unless the
    user is ``current_user`` (the request currently generating). Malformed
    frames are ignored instead of tearing down the connection.
    """
    # NOTE(review): no route decorator is visible for this handler; the page
    # connects to "/ws" — confirm it is registered with @app.websocket("/ws").
    await websocket.accept()
    print("WebSocket connection established.")
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except ValueError:
                continue
            if not isinstance(message, dict):
                continue
            user_id = message.get("user_id")
            action = message.get("action")
            print(user_id, action)
            if action == "disconnect":
                if user_id == current_user:
                    continue
                # Removing items while iterating a deque raises RuntimeError
                # ("deque mutated during iteration"), so remove by value instead.
                # queue.mutex guards the deque against the request threads.
                with queue.mutex:
                    while user_id in queue.queue:
                        queue.queue.remove(user_id)
                print(f"User {user_id} disconnected.")
    except WebSocketDisconnect:
        print("WebSocket connection closed.")
def form_get(request: Request):
    """Render the initial page with an empty form.

    The client's IP is hashed into a user ID; the inline page script
    (``script_template``) stores it, together with the ``"generate"`` state, in
    the form's hidden fields.

    Args:
        request: Incoming request; its client IP seeds the user ID.

    Returns:
        The rendered HTML page as a string.
    """
    # NOTE(review): no route decorator is visible for this handler; presumably
    # it serves GET "/" as HTML — confirm it is registered, e.g. with
    # @app.get("/", response_class=HTMLResponse).
    client_ip = request.client.host
    user_id = generate_user_id(client_ip)
    script = script_template.format(user_id=user_id, state="generate")
    print(client_ip, user_id)
    return Template(html_template).render(
        idea="This is an example of the idea generation",
        error=None,
        reply_count=reply_count,
        # Without a value the page would read "Time Taken:  seconds".
        time_taken=0,
        button_text="Generate",
        loading_text="Generating content, Usually takes 3-4 minutes, please wait...",
        script=script,
    )
def _js_string(value: str) -> str:
    """Encode *value* as a JavaScript string literal that is safe inside <script>.

    json.dumps escapes quotes, backslashes and control characters and, with its
    default ensure_ascii=True, every non-ASCII character. Replacing "</" with
    "<\\/" additionally stops a "</script>" inside the value from closing the
    surrounding tag.
    """
    return json.dumps(value).replace("</", "<\\/")


def _is_issued_user_id(user_id: str) -> bool:
    """Return True if *user_id* has the shape produced by generate_user_id."""
    # generate_user_id returns an MD5 hex digest: 32 lowercase hex characters.
    return len(user_id) == 32 and all(ch in "0123456789abcdef" for ch in user_id)


def _wait_for_turn(user_id: str, poll_seconds: float = 2.0) -> bool:
    """Block until *user_id* is at the head of ``queue``.

    Returns True when it is this user's turn. Returns False when the request
    must be reported as cancelled: either the user's entry is gone (removed by
    the websocket "disconnect" handler or cleared by reset_counter), or the
    user has queued more than once, in which case the extra entries are
    dropped so only one remains.
    """
    while True:
        # Other threads mutate the underlying deque directly; iterating it
        # during a mutation raises RuntimeError, so inspect it under its mutex.
        with queue.mutex:
            pending = queue.queue
            if user_id not in pending:
                return False
            if pending[0] == user_id:
                return True
            extra = pending.count(user_id) - 1
            for _ in range(extra):
                pending.remove(user_id)
        if extra > 0:
            return False
        time.sleep(poll_seconds)


def _leave_queue(user_id: str):
    """Remove the first queued entry for *user_id* without blocking.

    Returns *user_id* if an entry was removed, else None (e.g. after
    reset_counter cleared the queue). Unlike queue.get(), this neither blocks
    forever on an empty queue nor pops another user's entry.
    """
    with queue.mutex:
        if user_id in queue.queue:
            queue.queue.remove(user_id)
            return user_id
    return None


def form_post(request: Request, topic: str = Form(...), user_id: str = Form(...), state: str = Form(...)):
    """Handle a topic submission: queue it, wait for its turn, then generate.

    * state "generate" while other requests are queued: nothing is enqueued;
      the page comes back with a "Continue" button and the topic pre-filled so
      the user can explicitly join the queue.
    * otherwise the user ID is enqueued, the handler polls until it reaches the
      head of the queue, then runs DeepResearchAgent under ``lock`` and renders
      the idea (or the error).

    Args:
        request: Incoming request; its client IP seeds a fallback user ID.
        topic: Research topic typed by the user.
        user_id: ID stored in the page's hidden field by the inline script.
        state: "generate" or "continue" from the page's hidden field; empty or
            unknown values are treated as "generate".

    Returns:
        The rendered HTML page as a string.
    """
    # NOTE(review): no route decorator is visible for this handler; the form
    # posts to "/" — confirm it is registered, e.g. with
    # @app.post("/", response_class=HTMLResponse).
    global reply_count
    global current_user
    print("current queue length", len(queue.queue))
    client_ip = request.client.host
    # user_id and state are echoed back into an inline <script>; accept only
    # values this server could have issued so they cannot inject markup.
    if not _is_issued_user_id(user_id):
        user_id = generate_user_id(client_ip)
    if state not in ("generate", "continue"):
        state = "generate"
    script = script_template.format(user_id=user_id, state=state)
    loading_text = "Generating content, Usually takes 3-4 minutes, please wait..."

    if state == "generate" and not queue.empty():
        queue_len = queue.qsize()
        if queue_len + reply_count >= MAX_REPLIES_PER_DAY:
            error_message = "Today's maximum number of replies has been reached. Please try again tomorrow."
            return Template(html_template).render(idea="", error=error_message, reply_count=reply_count, button_text="Generate", loading_text=loading_text, script=script)
        with queue.mutex:
            already_queued = user_id in queue.queue
        if already_queued:
            error_message = "You already have a request in the queue. Submitting a new request will cancel the previous request. Please confirm if you need to submit a new request. If so, click continue. There are currently {} requests being processed.".format(queue_len)
        else:
            error_message = "There are currently {} requests being processed. If you want to queue, please write your original topic and click the Continue button and you will enter the queue.".format(queue_len)
        new_state = "continue"
        new_button_text = "Continue"
        # The topic is user input: encode it as a JS literal rather than
        # splicing it between quotes, which allowed script injection.
        script = f"""
function setstate() {{
    document.getElementById("user_id").value = "{user_id}";
    document.getElementById("state").value = "{new_state}";
    let userId = document.getElementById("user_id").value;
    let state = document.getElementById("state").value;
    console.log(`1 User ID: ${{userId}}, State: ${{state}}`);
    document.getElementById("topic").value = {_js_string(topic)};
}}
window.onload = setstate;
"""
        print(f"current1 user_id={user_id}, state={new_state}")
        return Template(html_template).render(idea="", error=error_message, reply_count=reply_count, button_text=new_button_text, loading_text=f"Generating content, Usually takes {(queue_len+1)*3}-{(queue_len+1)*4} minutes, please wait...", script=script)

    queue.put(user_id)
    new_button_text = "Generate"
    script = script_template.format(user_id=user_id, state="generate")
    with queue.mutex:
        head = queue.queue[0] if queue.queue else None
    print(head, [user_id, topic])
    # Wait until it is this user's turn; bail out if the request was cancelled.
    if not _wait_for_turn(user_id):
        return Template(html_template).render(idea="", error="Request was cancelled.", reply_count=reply_count, button_text="Generate", loading_text=loading_text, script=script)

    left_queue = False
    try:
        with lock:
            current_user = user_id
            try:
                logging.info(f"Processing request for topic: {topic}")
                start_time = time.time()
                error_message = None
                idea = ""
                time_taken = 0
                if reply_count >= MAX_REPLIES_PER_DAY:
                    # The daily cap may have been reached while this request
                    # waited; skip the agent so the cap is not exceeded.
                    error_message = "Today's maximum number of replies has been reached. Please try again tomorrow."
                    logging.info("Today's maximum number of replies has been reached. Please try again tomorrow.")
                else:
                    try:
                        main_llm, cheap_llm = get_llms()
                        deep_research_agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1, max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
                        print(f"begin to generate idea of topic {topic}")
                        (idea, _related_experiments, _entities, _idea_chain, _ideas,
                         _trend, _future, _human, _year) = deep_research_agent.generate_idea_with_chain(topic)
                        idea = markdown.markdown(fix_markdown(idea))
                        reply_count += 1
                        time_taken = round(time.time() - start_time, 2)
                        logging.info(f"Successfully generated idea for topic: {topic}")
                    except Exception as e:
                        time_taken = round(time.time() - start_time, 2)
                        logging.error(f"Failed to generate idea for topic: {topic}, Error: {str(e)}")
                        error_message = str(e)
            finally:
                # Free the slot even if something above raised, so the next
                # user can proceed and disconnects are no longer ignored.
                current_user = None
                finished = _leave_queue(user_id)
                left_queue = True
            print(f"finished: {finished}, still in queue: {queue.qsize()}")
            return Template(html_template).render(idea=idea, error=error_message, reply_count=reply_count, time_taken=time_taken, button_text=new_button_text, loading_text=loading_text, script=script)
    except Exception as e:
        error_message = str(e)
        # Only dequeue if the finally above did not already do it.
        if not left_queue:
            _leave_queue(user_id)
        return Template(html_template).render(idea="", error=error_message, reply_count=reply_count, button_text="Generate", loading_text=loading_text, script=script)