Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import asyncio | |
| import logging | |
| import os | |
| import re | |
| from uuid import uuid4 | |
| import tornado | |
| import tornado.websocket | |
| from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed | |
| from tornado.escape import json_decode, json_encode, url_escape | |
| from tornado.httpclient import AsyncHTTPClient, HTTPRequest | |
| from tornado.ioloop import PeriodicCallback | |
| from tornado.websocket import websocket_connect | |
# Configure the root logger at INFO so the logging.info(...) calls in this
# module (kernel creation, connection, interrupts) are actually emitted.
logging.basicConfig(level=logging.INFO)
# One ECMA-48 control sequence: a CSI introducer (7-bit ESC '[' or 8-bit 0x9B),
# any parameter bytes (0x30-0x3F), any intermediate bytes (0x20-0x2F) and a
# single final byte (0x40-0x7E). Compiled once at import time.
_ANSI_CSI_RE = re.compile(r'(?:\x1B\[|\x9B)[0-?]*[ -/]*[@-~]')


def strip_ansi(o: str) -> str:
    """Removes ANSI escape sequences from `o`, as defined by ECMA-048 in
    http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf

    Every CSI control sequence is removed: SGR colour codes with any number
    of parameters (including the bare reset ``ESC [ m`` and 24-bit colours
    ``38;2;R;G;B``) as well as cursor/erase sequences such as ``ESC [ K``.

    Adapted from
    https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py

    >>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
    'Lorem ipsum'
    >>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
    'Lorem Ipsum sit\\namet.'
    >>> strip_ansi("")
    ''
    >>> strip_ansi("\\x1b[0m")
    ''
    >>> strip_ansi("Lorem")
    'Lorem'
    >>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
    'Lorem ipsum'
    >>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
    'Lorem dolor sit ipsum'
    >>> strip_ansi("\\x1b[38;2;255;0;0mred\\x1b[m")
    'red'
    >>> strip_ansi("50%\\x1b[K done")
    '50% done'
    """
    return _ANSI_CSI_RE.sub('', o)
| class JupyterKernel: | |
| def __init__(self, url_suffix: str, convid: str, lang: str = 'python') -> None: | |
| self.base_url = f'http://{url_suffix}' | |
| self.base_ws_url = f'ws://{url_suffix}' | |
| self.lang = lang | |
| self.kernel_id: str | None = None | |
| self.ws: tornado.websocket.WebSocketClientConnection | None = None | |
| self.convid = convid | |
| logging.info( | |
| f'Jupyter kernel created for conversation {convid} at {url_suffix}' | |
| ) | |
| self.heartbeat_interval = 10000 # 10 seconds | |
| self.heartbeat_callback: PeriodicCallback | None = None | |
| self.initialized = False | |
| async def initialize(self) -> None: | |
| await self.execute(r'%colors nocolor') | |
| # pre-defined tools | |
| self.tools_to_run: list[str] = [ | |
| # TODO: You can add code for your pre-defined tools here | |
| ] | |
| for tool in self.tools_to_run: | |
| res = await self.execute(tool) | |
| logging.info(f'Tool [{tool}] initialized:\n{res}') | |
| self.initialized = True | |
| async def _send_heartbeat(self) -> None: | |
| if not self.ws: | |
| return | |
| try: | |
| self.ws.ping() | |
| # logging.info('Heartbeat sent...') | |
| except tornado.iostream.StreamClosedError: | |
| # logging.info('Heartbeat failed, reconnecting...') | |
| try: | |
| await self._connect() | |
| except ConnectionRefusedError: | |
| logging.info( | |
| 'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?' | |
| ) | |
| async def _connect(self) -> None: | |
| if self.ws: | |
| self.ws.close() | |
| self.ws = None | |
| client = AsyncHTTPClient() | |
| if not self.kernel_id: | |
| n_tries = 5 | |
| while n_tries > 0: | |
| try: | |
| response = await client.fetch( | |
| '{}/api/kernels'.format(self.base_url), | |
| method='POST', | |
| body=json_encode({'name': self.lang}), | |
| ) | |
| kernel = json_decode(response.body) | |
| self.kernel_id = kernel['id'] | |
| break | |
| except Exception: | |
| # kernels are not ready yet | |
| n_tries -= 1 | |
| await asyncio.sleep(1) | |
| if n_tries == 0: | |
| raise ConnectionRefusedError('Failed to connect to kernel') | |
| ws_req = HTTPRequest( | |
| url='{}/api/kernels/{}/channels'.format( | |
| self.base_ws_url, url_escape(self.kernel_id) | |
| ) | |
| ) | |
| self.ws = await websocket_connect(ws_req) | |
| logging.info('Connected to kernel websocket') | |
| # Setup heartbeat | |
| if self.heartbeat_callback: | |
| self.heartbeat_callback.stop() | |
| self.heartbeat_callback = PeriodicCallback( | |
| self._send_heartbeat, self.heartbeat_interval | |
| ) | |
| self.heartbeat_callback.start() | |
| # type: ignore | |
| async def execute( | |
| self, code: str, timeout: int = 120 | |
| ) -> dict[str, list[str] | str]: | |
| if not self.ws or self.ws.stream.closed(): | |
| await self._connect() | |
| msg_id = uuid4().hex | |
| assert self.ws is not None | |
| res = await self.ws.write_message( | |
| json_encode( | |
| { | |
| 'header': { | |
| 'username': '', | |
| 'version': '5.0', | |
| 'session': '', | |
| 'msg_id': msg_id, | |
| 'msg_type': 'execute_request', | |
| }, | |
| 'parent_header': {}, | |
| 'channel': 'shell', | |
| 'content': { | |
| 'code': code, | |
| 'silent': False, | |
| 'store_history': False, | |
| 'user_expressions': {}, | |
| 'allow_stdin': False, | |
| }, | |
| 'metadata': {}, | |
| 'buffers': {}, | |
| } | |
| ) | |
| ) | |
| logging.info(f'Executed code in jupyter kernel:\n{res}') | |
| outputs: list[dict] = [] | |
| async def wait_for_messages() -> bool: | |
| execution_done = False | |
| while not execution_done: | |
| assert self.ws is not None | |
| msg = await self.ws.read_message() | |
| if msg is None: | |
| continue | |
| msg_dict = json_decode(msg) | |
| msg_type = msg_dict['msg_type'] | |
| parent_msg_id = msg_dict['parent_header'].get('msg_id', None) | |
| if parent_msg_id != msg_id: | |
| continue | |
| if os.environ.get('DEBUG'): | |
| logging.info( | |
| f'MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict["content"]}' | |
| ) | |
| if msg_type == 'error': | |
| traceback = '\n'.join(msg_dict['content']['traceback']) | |
| outputs.append({'type': 'text', 'content': traceback}) | |
| execution_done = True | |
| elif msg_type == 'stream': | |
| outputs.append( | |
| {'type': 'text', 'content': msg_dict['content']['text']} | |
| ) | |
| elif msg_type in ['execute_result', 'display_data']: | |
| outputs.append( | |
| { | |
| 'type': 'text', | |
| 'content': msg_dict['content']['data']['text/plain'], | |
| } | |
| ) | |
| if 'image/png' in msg_dict['content']['data']: | |
| # Store image data in structured format | |
| image_url = f'data:image/png;base64,{msg_dict["content"]["data"]["image/png"]}' | |
| outputs.append({'type': 'image', 'content': image_url}) | |
| elif msg_type == 'execute_reply': | |
| execution_done = True | |
| return execution_done | |
| async def interrupt_kernel() -> None: | |
| client = AsyncHTTPClient() | |
| if self.kernel_id is None: | |
| return | |
| interrupt_response = await client.fetch( | |
| f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt', | |
| method='POST', | |
| body=json_encode({'kernel_id': self.kernel_id}), | |
| ) | |
| logging.info(f'Kernel interrupted: {interrupt_response}') | |
| try: | |
| execution_done = await asyncio.wait_for(wait_for_messages(), timeout) | |
| except asyncio.TimeoutError: | |
| await interrupt_kernel() | |
| return {'text': f'[Execution timed out ({timeout} seconds).]', 'images': []} | |
| # Process structured outputs | |
| text_outputs = [] | |
| image_outputs = [] | |
| for output in outputs: | |
| if output['type'] == 'text': | |
| text_outputs.append(output['content']) | |
| elif output['type'] == 'image': | |
| image_outputs.append(output['content']) | |
| if not text_outputs and execution_done: | |
| text_content = '[Code executed successfully with no output]' | |
| else: | |
| text_content = ''.join(text_outputs) | |
| # Remove ANSI from text content | |
| text_content = strip_ansi(text_content) | |
| # Return a dictionary with text content and image URLs | |
| return {'text': text_content, 'images': image_outputs} | |
| async def shutdown_async(self) -> None: | |
| if self.kernel_id: | |
| client = AsyncHTTPClient() | |
| await client.fetch( | |
| '{}/api/kernels/{}'.format(self.base_url, self.kernel_id), | |
| method='DELETE', | |
| ) | |
| self.kernel_id = None | |
| if self.ws: | |
| self.ws.close() | |
| self.ws = None | |
class ExecuteHandler(tornado.web.RequestHandler):
    """``POST`` handler: run ``{"code": "..."}`` in the kernel and return JSON output."""

    def initialize(self, jupyter_kernel: JupyterKernel) -> None:
        """Receive the shared kernel client from the application's route kwargs."""
        self.jupyter_kernel = jupyter_kernel

    async def post(self) -> None:
        """Execute the request's ``code`` and write the kernel's structured output.

        Responds 400 when the body is not valid JSON, is not a JSON object,
        or has no non-empty ``code`` field.
        """
        try:
            data = json_decode(self.request.body)
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors;
            # a malformed client request should not surface as a 500.
            self.set_status(400)
            self.write('Invalid JSON body')
            return
        code = data.get('code') if isinstance(data, dict) else None
        if not code:
            self.set_status(400)
            self.write('Missing code')
            return
        output = await self.jupyter_kernel.execute(code)
        # Set content type to JSON and return the structured output
        self.set_header('Content-Type', 'application/json')
        self.write(json_encode(output))
def make_app() -> tornado.web.Application:
    """Create and initialize the kernel client, then build the web application.

    The kernel address comes from ``JUPYTER_GATEWAY_PORT`` (default ``8888``)
    on localhost. ``JUPYTER_GATEWAY_KERNEL_ID`` (default ``'default'``) is
    passed as the conversation id.

    Initialization runs on a freshly created event loop that is installed as
    this thread's current loop. This replaces the deprecated implicit loop
    creation of ``asyncio.get_event_loop()``. A later ``IOLoop.current()``
    in this thread therefore runs on the same loop that owns the kernel
    websocket and heartbeat.
    """
    jupyter_kernel = JupyterKernel(
        f'localhost:{os.environ.get("JUPYTER_GATEWAY_PORT", "8888")}',
        os.environ.get('JUPYTER_GATEWAY_KERNEL_ID', 'default'),
    )
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(jupyter_kernel.initialize())
    return tornado.web.Application(
        [
            (r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}),
        ]
    )
if __name__ == '__main__':
    # Validate the listen port before connecting to the kernel, so a missing
    # setting fails fast with a clear message instead of listen(None).
    port = os.environ.get('JUPYTER_EXEC_SERVER_PORT')
    if not port:
        raise SystemExit('JUPYTER_EXEC_SERVER_PORT environment variable must be set')
    app = make_app()
    app.listen(int(port))
    tornado.ioloop.IOLoop.current().start()