# MindSearch API — FastAPI service exposing the agent over Server-Sent Events.
| import asyncio | |
| import json | |
| import logging | |
| import random | |
| from typing import Dict, List, Union | |
| import janus | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.requests import Request | |
| from pydantic import BaseModel, Field | |
| from sse_starlette.sse import EventSourceResponse | |
| from mindsearch.agent import init_agent | |
def parse_arguments():
    """Parse the service's command-line options.

    Returns:
        argparse.Namespace with ``host``, ``port``, ``lang``,
        ``model_format``, ``search_engine`` and ``asy`` attributes.
    """
    import argparse

    arg_parser = argparse.ArgumentParser(description="MindSearch API")
    arg_parser.add_argument("--host", default="0.0.0.0", type=str, help="Service host")
    arg_parser.add_argument("--port", default=8002, type=int, help="Service port")
    arg_parser.add_argument("--lang", default="cn", type=str, help="Language")
    arg_parser.add_argument(
        "--model_format", default="internlm_server", type=str, help="Model format"
    )
    arg_parser.add_argument(
        "--search_engine", default="BingSearch", type=str, help="Search engine"
    )
    # Flag: when present, the async agent/route is used (see add_api_route below).
    arg_parser.add_argument("--asy", default=False, action="store_true", help="Agent mode")
    parsed = arg_parser.parse_args()
    return parsed
# Parse CLI options once at import time; `args` is read by the route handlers below.
args = parse_arguments()

# Serve the interactive OpenAPI docs at the root path.
app = FastAPI(docs_url="/")

# NOTE(review): fully open CORS (any origin, any method, with credentials) —
# acceptable for a demo, but tighten `allow_origins` before public exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerationParams(BaseModel):
    """Request body accepted by the /solve endpoint."""

    # Either a raw query string or a chat-style list of message dicts.
    inputs: Union[str, List[Dict]]
    # Random id by default; keys per-session agent memory (popped on stream end).
    session_id: int = Field(default_factory=lambda: random.randint(0, 999999))
    # Extra agent configuration; not read by the handlers visible in this file.
    agent_cfg: Dict = dict()
| def _postprocess_agent_message(message: dict) -> dict: | |
| content, fmt = message["content"], message["formatted"] | |
| current_node = content["current_node"] if isinstance(content, dict) else None | |
| if current_node: | |
| message["content"] = None | |
| for key in ["ref2url"]: | |
| fmt.pop(key, None) | |
| graph = fmt["node"] | |
| for key in graph.copy(): | |
| if key != current_node: | |
| graph.pop(key) | |
| if current_node not in ["root", "response"]: | |
| node = graph[current_node] | |
| for key in ["memory", "session_id"]: | |
| node.pop(key, None) | |
| node_fmt = node["response"]["formatted"] | |
| if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt: | |
| node["response"]["content"] = None | |
| node_fmt["thought"] = ( | |
| node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0] | |
| ) | |
| if isinstance(node_fmt["action"], str): | |
| node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0] | |
| else: | |
| if isinstance(fmt, dict) and "thought" in fmt and "action" in fmt: | |
| message["content"] = None | |
| fmt["thought"] = fmt["thought"] and fmt["thought"].split("<|action_start|>")[0] | |
| if isinstance(fmt["action"], str): | |
| fmt["action"] = fmt["action"].split("<|action_end|>")[0] | |
| for key in ["node"]: | |
| fmt.pop(key, None) | |
| return dict(current_node=current_node, response=message) | |
async def run(request: GenerationParams, _request: Request):
    """Handle POST /solve with a synchronous agent, streamed back over SSE.

    Bridges the agent's blocking generator into the event loop through a
    janus queue fed by a worker thread; each produced message is
    post-processed and emitted as one SSE ``data`` event.
    """

    async def generate():
        try:
            # janus.Queue exposes a sync endpoint for the worker thread and an
            # async endpoint for this coroutine.
            queue = janus.Queue()
            stop_event = asyncio.Event()

            # Wrapping a sync generator as an async generator using run_in_executor
            def sync_generator_wrapper():
                try:
                    for response in agent(inputs, session_id=session_id):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(f"Exception in sync_generator_wrapper: {e}")
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    # None is the end-of-stream sentinel.
                    queue.sync_q.put(None)

            async def async_generator_wrapper():
                loop = asyncio.get_event_loop()
                # Fire-and-forget: the returned future is not awaited, so worker
                # errors surface only via the logging call above.
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                # NOTE(review): despite the comment below, nothing in
                # sync_generator_wrapper ever checks stop_event — the worker
                # thread always runs the agent to completion.
                stop_event.set()  # Inform sync_generator_wrapper to stop

            async for message in async_generator_wrapper():
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                if await _request.is_disconnected():
                    # NOTE(review): breaking here closes async_generator_wrapper
                    # before it reaches stop_event.set(), so the `finally` below
                    # may wait on stop_event forever — confirm and fix upstream.
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            await stop_event.wait()  # Waiting for async_generator_wrapper to stop
            queue.close()
            await queue.wait_closed()
            # Drop this session's conversation memory once streaming ends.
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    # A fresh agent per request; generate() closes over inputs/session_id/agent.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
    )
    return EventSourceResponse(generate(), ping=300)
async def run_async(request: GenerationParams, _request: Request):
    """Handle POST /solve with the fully-async agent, streamed back over SSE.

    Each agent message is post-processed and emitted as one SSE ``data``
    event; on failure a single error event is emitted instead.
    """
    inputs = request.inputs
    session_id = request.session_id
    # A fresh async agent per request; generate() closes over these locals.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
        use_async=True,
    )

    async def generate():
        try:
            async for message in agent(inputs, session_id=session_id):
                payload = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": payload}
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            payload = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": payload}
        finally:
            # Drop this session's conversation memory once streaming ends.
            agent.agent.memory.memory_map.pop(session_id, None)

    return EventSourceResponse(generate(), ping=300)
# Route /solve to the async or sync handler depending on the --asy flag.
app.add_api_route("/solve", run_async if args.asy else run, methods=["POST"])

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")