Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import logging | |
| import uuid | |
| from typing import Optional | |
| import aiohttp | |
| import websockets | |
| from pydantic import BaseModel | |
| from open_webui.env import SRC_LOG_LEVELS | |
# Module-level logger; its level is read from the "MAIN" entry of SRC_LOG_LEVELS.
logger = logging.getLogger(__name__)
logger.setLevel(level=SRC_LOG_LEVELS["MAIN"])
class ResultModel(BaseModel):
    """
    Container for the outcome of one code execution.

    Every field is an optional string that defaults to the empty string.
    """

    # Text collected from the kernel's "stdout" stream messages.
    stdout: Optional[str] = ""

    # Text collected from "stderr" stream messages and error tracebacks,
    # or an "Error: ..." description when the execution itself failed.
    stderr: Optional[str] = ""

    # Newline-joined display/execute results (PNG data URIs or plain text).
    result: Optional[str] = ""
| class JupyterCodeExecuter: | |
| """ | |
| Execute code in jupyter notebook | |
| """ | |
| def __init__( | |
| self, | |
| base_url: str, | |
| code: str, | |
| token: str = "", | |
| password: str = "", | |
| timeout: int = 60, | |
| ): | |
| """ | |
| :param base_url: Jupyter server URL (e.g., "http://localhost:8888") | |
| :param code: Code to execute | |
| :param token: Jupyter authentication token (optional) | |
| :param password: Jupyter password (optional) | |
| :param timeout: WebSocket timeout in seconds (default: 60s) | |
| """ | |
| self.base_url = base_url | |
| self.code = code | |
| self.token = token | |
| self.password = password | |
| self.timeout = timeout | |
| self.kernel_id = "" | |
| if self.base_url[-1] != "/": | |
| self.base_url += "/" | |
| self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url) | |
| self.params = {} | |
| self.result = ResultModel() | |
| async def __aenter__(self): | |
| return self | |
| async def __aexit__(self, exc_type, exc_val, exc_tb): | |
| if self.kernel_id: | |
| try: | |
| async with self.session.delete( | |
| f"api/kernels/{self.kernel_id}", params=self.params | |
| ) as response: | |
| response.raise_for_status() | |
| except Exception as err: | |
| logger.exception("close kernel failed, %s", err) | |
| await self.session.close() | |
| async def run(self) -> ResultModel: | |
| try: | |
| await self.sign_in() | |
| await self.init_kernel() | |
| await self.execute_code() | |
| except Exception as err: | |
| logger.exception("execute code failed, %s", err) | |
| self.result.stderr = f"Error: {err}" | |
| return self.result | |
| async def sign_in(self) -> None: | |
| # password authentication | |
| if self.password and not self.token: | |
| async with self.session.get("login") as response: | |
| response.raise_for_status() | |
| xsrf_token = response.cookies["_xsrf"].value | |
| if not xsrf_token: | |
| raise ValueError("_xsrf token not found") | |
| self.session.cookie_jar.update_cookies(response.cookies) | |
| self.session.headers.update({"X-XSRFToken": xsrf_token}) | |
| async with self.session.post( | |
| "login", | |
| data={"_xsrf": xsrf_token, "password": self.password}, | |
| allow_redirects=False, | |
| ) as response: | |
| response.raise_for_status() | |
| self.session.cookie_jar.update_cookies(response.cookies) | |
| # token authentication | |
| if self.token: | |
| self.params.update({"token": self.token}) | |
| async def init_kernel(self) -> None: | |
| async with self.session.post(url="api/kernels", params=self.params) as response: | |
| response.raise_for_status() | |
| kernel_data = await response.json() | |
| self.kernel_id = kernel_data["id"] | |
| def init_ws(self) -> (str, dict): | |
| ws_base = self.base_url.replace("http", "ws", 1) | |
| ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()]) | |
| websocket_url = f"{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}" | |
| ws_headers = {} | |
| if self.password and not self.token: | |
| ws_headers = { | |
| "Cookie": "; ".join( | |
| [ | |
| f"{cookie.key}={cookie.value}" | |
| for cookie in self.session.cookie_jar | |
| ] | |
| ), | |
| **self.session.headers, | |
| } | |
| return websocket_url, ws_headers | |
| async def execute_code(self) -> None: | |
| # initialize ws | |
| websocket_url, ws_headers = self.init_ws() | |
| # execute | |
| async with websockets.connect( | |
| websocket_url, additional_headers=ws_headers | |
| ) as ws: | |
| await self.execute_in_jupyter(ws) | |
| async def execute_in_jupyter(self, ws) -> None: | |
| # send message | |
| msg_id = uuid.uuid4().hex | |
| await ws.send( | |
| json.dumps( | |
| { | |
| "header": { | |
| "msg_id": msg_id, | |
| "msg_type": "execute_request", | |
| "username": "user", | |
| "session": uuid.uuid4().hex, | |
| "date": "", | |
| "version": "5.3", | |
| }, | |
| "parent_header": {}, | |
| "metadata": {}, | |
| "content": { | |
| "code": self.code, | |
| "silent": False, | |
| "store_history": True, | |
| "user_expressions": {}, | |
| "allow_stdin": False, | |
| "stop_on_error": True, | |
| }, | |
| "channel": "shell", | |
| } | |
| ) | |
| ) | |
| # parse message | |
| stdout, stderr, result = "", "", [] | |
| while True: | |
| try: | |
| # wait for message | |
| message = await asyncio.wait_for(ws.recv(), self.timeout) | |
| message_data = json.loads(message) | |
| # msg id not match, skip | |
| if message_data.get("parent_header", {}).get("msg_id") != msg_id: | |
| continue | |
| # check message type | |
| msg_type = message_data.get("msg_type") | |
| match msg_type: | |
| case "stream": | |
| if message_data["content"]["name"] == "stdout": | |
| stdout += message_data["content"]["text"] | |
| elif message_data["content"]["name"] == "stderr": | |
| stderr += message_data["content"]["text"] | |
| case "execute_result" | "display_data": | |
| data = message_data["content"]["data"] | |
| if "image/png" in data: | |
| result.append(f"data:image/png;base64,{data['image/png']}") | |
| elif "text/plain" in data: | |
| result.append(data["text/plain"]) | |
| case "error": | |
| stderr += "\n".join(message_data["content"]["traceback"]) | |
| case "status": | |
| if message_data["content"]["execution_state"] == "idle": | |
| break | |
| except asyncio.TimeoutError: | |
| stderr += "\nExecution timed out." | |
| break | |
| self.result.stdout = stdout.strip() | |
| self.result.stderr = stderr.strip() | |
| self.result.result = "\n".join(result).strip() if result else "" | |
async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    """
    Run ``code`` on the Jupyter server at ``base_url`` and return the outcome.

    A ``JupyterCodeExecuter`` is created for this single call and used as an
    async context manager, so its cleanup runs when the call finishes.

    :param base_url: Jupyter server URL
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: WebSocket timeout in seconds
    :return: the execution result model dumped to a plain dict
    """
    runner = JupyterCodeExecuter(
        base_url=base_url,
        code=code,
        token=token,
        password=password,
        timeout=timeout,
    )
    async with runner as executor:
        outcome = await executor.run()
        return outcome.model_dump()