# Spaces: Sleeping
| import asyncio | |
| import json | |
| import logging | |
| from copy import deepcopy | |
| from dataclasses import asdict | |
| from typing import Dict, List, Union | |
| import janus | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from lagent.schema import AgentStatusCode | |
| from pydantic import BaseModel | |
| from sse_starlette.sse import EventSourceResponse | |
| from mindsearch.agent import init_agent | |
def parse_arguments():
    """Parse the command-line options of the MindSearch API server.

    Returns:
        argparse.Namespace: Carries ``lang``, ``model_format`` and
        ``search_engine``; each is a string and falls back to its default
        when the flag is not given.
    """
    import argparse

    cli = argparse.ArgumentParser(description='MindSearch API')

    # (flag, default value, help text) for every supported string option.
    string_options = (
        ('--lang', 'cn', 'Language'),
        ('--model_format', 'internlm_server', 'Model format'),
        ('--search_engine', 'DuckDuckGoSearch', 'Search engine'),
    )
    for flag, fallback, description in string_options:
        cli.add_argument(flag, default=fallback, type=str, help=description)

    return cli.parse_args()
# Command-line options are parsed once, at import time; the request handler
# reads them when it builds an agent.
args = parse_arguments()

# ``docs_url='/'`` moves FastAPI's interactive docs page to the site root.
app = FastAPI(docs_url='/')

# NOTE(review): wildcard origins combined with ``allow_credentials=True`` is
# a very permissive CORS policy -- confirm this is intended for deployment.
_cors_options = {
    'allow_origins': ['*'],
    'allow_credentials': True,
    'allow_methods': ['*'],
    'allow_headers': ['*'],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Request body accepted by the ``run`` handler.
class GenerationParams(BaseModel):
    # Either a plain prompt string or a list of dicts
    # (presumably chat messages -- TODO confirm against the agent).
    inputs: Union[str, List[Dict]]
    # Optional agent configuration; defaults to an empty dict.
    agent_cfg: Dict = {}
async def run(request: GenerationParams):
    """Stream the agent's responses for ``request.inputs`` as server-sent events.

    A new agent is created for every call via ``init_agent`` using the
    module-level command-line ``args``.  The agent's blocking
    ``stream_chat`` generator runs in the default executor and its items are
    handed to the event loop through a ``janus`` queue.

    Args:
        request: Request body; only ``request.inputs`` is read here.

    Returns:
        EventSourceResponse: Wraps the ``generate`` async generator.  Each
        event is ``{'data': <json>}`` where the JSON holds either
        ``response``/``current_node`` or, on failure, ``error``.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on ``app``.
    import threading

    def convert_adjacency_to_tree(adjacency_input, root_name):
        """Turn an adjacency mapping into a nested ``name``/``children`` tree.

        ``adjacency_input`` maps a node name to a list of child records, each
        providing ``name``, ``state`` and ``id``.
        """

        def build_tree(node_name):
            node = {'name': node_name, 'children': []}
            if node_name in adjacency_input:
                for child in adjacency_input[node_name]:
                    child_node = build_tree(child['name'])
                    child_node['state'] = child['state']
                    child_node['id'] = child['id']
                    node['children'].append(child_node)
            return node

        return build_tree(root_name)

    async def generate():
        """Yield one ``{'data': <json>}`` event per agent response."""
        # Bound before ``try`` so that cleanup in ``finally`` can never hit
        # an unbound name, even if creating the queue fails.
        queue = None
        # A thread-safe flag: it is set from the event loop and polled from
        # the producer thread.
        stop_event = threading.Event()
        try:
            queue = janus.Queue()

            # Wrapping a sync generator as an async generator using
            # run_in_executor.
            def sync_generator_wrapper():
                try:
                    for response in agent.stream_chat(inputs):
                        if stop_event.is_set():
                            # The consumer is gone; stop producing.
                            break
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(
                        f'Exception in sync_generator_wrapper: {e}')
                finally:
                    # Notify async_generator_wrapper that the data
                    # generation is complete.
                    try:
                        queue.sync_q.put(None)
                    except Exception:
                        # Once the consumer has shut down, the queue is
                        # closed and nobody waits for the sentinel.
                        if not stop_event.is_set():
                            raise

            async def async_generator_wrapper():
                loop = asyncio.get_running_loop()
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                    if not isinstance(
                            response,
                            tuple) and response.state == AgentStatusCode.END:
                        break

            async for response in async_generator_wrapper():
                if isinstance(response, tuple):
                    agent_return, node_name = response
                else:
                    agent_return = response
                    node_name = None
                # Keep the flat adjacency list: the attribute is replaced by
                # the nested tree below, and both are sent to the client.
                origin_adj = deepcopy(agent_return.adjacency_list)
                adjacency_list = convert_adjacency_to_tree(
                    agent_return.adjacency_list, 'root')
                assert adjacency_list[
                    'name'] == 'root' and 'children' in adjacency_list
                agent_return.adjacency_list = adjacency_list['children']
                agent_return = asdict(agent_return)
                agent_return['adj'] = origin_adj
                response_json = json.dumps(dict(response=agent_return,
                                                current_node=node_name),
                                           ensure_ascii=False)
                yield {'data': response_json}
        except Exception as exc:
            msg = 'An error occurred while generating the response.'
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))),
                ensure_ascii=False)
            yield {'data': response_json}
        finally:
            # Signal the producer unconditionally.  Waiting here for the
            # consumer wrapper to finish would hang forever whenever the loop
            # above is left early (exception or cancellation), because the
            # wrapper is then suspended and never reaches its end.
            stop_event.set()
            if queue is not None:
                queue.close()
                await queue.wait_closed()

    inputs = request.inputs
    agent = init_agent(lang=args.lang,
                       model_format=args.model_format,
                       search_engine=args.search_engine)
    return EventSourceResponse(generate())
if __name__ == '__main__':
    # Script entry point: serve ``app`` with uvicorn on all interfaces.
    import uvicorn

    server_options = {'host': '0.0.0.0', 'port': 8002, 'log_level': 'info'}
    uvicorn.run(app, **server_options)