Spaces:
Build error
Build error
| import os | |
| import tempfile | |
| import threading | |
| from pathlib import Path | |
| from typing import Any | |
| from zipfile import ZipFile | |
| import httpcore | |
| import httpx | |
| from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential | |
| from openhands.core.config import OpenHandsConfig | |
| from openhands.core.config.mcp_config import ( | |
| MCPConfig, | |
| MCPSSEServerConfig, | |
| MCPStdioServerConfig, | |
| ) | |
| from openhands.core.exceptions import ( | |
| AgentRuntimeTimeoutError, | |
| ) | |
| from openhands.events import EventStream | |
| from openhands.events.action import ( | |
| ActionConfirmationStatus, | |
| AgentThinkAction, | |
| BrowseInteractiveAction, | |
| BrowseURLAction, | |
| CmdRunAction, | |
| FileEditAction, | |
| FileReadAction, | |
| FileWriteAction, | |
| IPythonRunCellAction, | |
| ) | |
| from openhands.events.action.action import Action | |
| from openhands.events.action.files import FileEditSource | |
| from openhands.events.action.mcp import MCPAction | |
| from openhands.events.observation import ( | |
| AgentThinkObservation, | |
| ErrorObservation, | |
| NullObservation, | |
| Observation, | |
| UserRejectObservation, | |
| ) | |
| from openhands.events.serialization import event_to_dict, observation_from_dict | |
| from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS | |
| from openhands.integrations.provider import PROVIDER_TOKEN_TYPE | |
| from openhands.runtime.base import Runtime | |
| from openhands.runtime.plugins import PluginRequirement | |
| from openhands.runtime.utils.request import send_request | |
| from openhands.utils.http_session import HttpSession | |
| from openhands.utils.tenacity_stop import stop_if_should_exit | |
def _is_retryable_error(exception: BaseException) -> bool:
    """Report whether ``exception`` is a transient protocol error worth retrying.

    Only ``RemoteProtocolError`` raised by either httpx or httpcore qualifies;
    every other exception type yields False.  Presumably intended as a tenacity
    ``retry_if_exception`` predicate (tenacity is imported by this module).
    """
    retryable_types = (httpcore.RemoteProtocolError, httpx.RemoteProtocolError)
    return isinstance(exception, retryable_types)
class ActionExecutionClient(Runtime):
    """Base class for runtimes that interact with the action execution server.

    This class contains shared logic between DockerRuntime and RemoteRuntime
    for interacting with the HTTP server defined in action_execution_server.py.
    Subclasses must override the ``action_execution_server_url`` property.
    """

    def __init__(
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
        status_callback: Any | None = None,
        attach_to_existing: bool = False,
        headless_mode: bool = True,
        user_id: str | None = None,
        git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
    ):
        # NOTE(review): these are set before super().__init__(), presumably because
        # the base initializer may already use them -- confirm against Runtime.
        self.session = HttpSession()
        # Serializes execute_action requests so only one action runs at a time.
        self.action_semaphore = threading.Semaphore(1)
        self._runtime_closed: bool = False
        # Cached VSCode connection token; None until successfully fetched.
        self._vscode_token: str | None = None
        # Stdio MCP servers last pushed successfully to the action execution server.
        self._last_updated_mcp_stdio_servers: list[MCPStdioServerConfig] = []
        super().__init__(
            config,
            event_stream,
            sid,
            plugins,
            env_vars,
            status_callback,
            attach_to_existing,
            headless_mode,
            user_id,
            git_provider_tokens,
        )

    @property
    def action_execution_server_url(self) -> str:
        """Base URL of the action execution server; subclasses must override it.

        Every request in this class reads it as an attribute
        (``f'{self.action_execution_server_url}/...'``), so it has to be a
        property rather than a plain method.
        """
        raise NotImplementedError('Action execution server URL is not implemented')

    @retry(
        # Retry only transient protocol disconnects (see _is_retryable_error).
        # reraise=True surfaces the last underlying exception instead of a
        # tenacity RetryError, so callers keep seeing the same exception types.
        retry=retry_if_exception(_is_retryable_error),
        stop=stop_after_attempt(5) | stop_if_should_exit(),
        wait=wait_exponential(multiplier=1, min=4, max=15),
        reraise=True,
    )
    def _send_action_server_request(
        self,
        method: str,
        url: str,
        **kwargs,
    ) -> httpx.Response:
        """Send a request to the action execution server.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: URL to send the request to
            **kwargs: Additional arguments forwarded to ``send_request``
                (e.g. ``json``, ``params``, ``files``, ``timeout``).

        Returns:
            Response from the server

        Raises:
            Whatever ``send_request`` raises on failure (previously documented as
            AgentRuntimeError -- confirm against send_request). Protocol-level
            disconnects are retried before being raised.
        """
        return send_request(self.session, method, url, **kwargs)

    def check_if_alive(self) -> None:
        """Probe the server's ``/alive`` endpoint with a 5-second timeout.

        Request errors propagate to the caller unchanged.
        """
        response = self._send_action_server_request(
            'GET',
            f'{self.action_execution_server_url}/alive',
            timeout=5,
        )
        assert response.is_closed

    def list_files(self, path: str | None = None) -> list[str]:
        """List files in the sandbox.

        Args:
            path: Directory to list. If None, list files in the sandbox's
                initial working directory (e.g., /workspace).

        Returns:
            The JSON list returned by the server's ``/list_files`` endpoint.

        Raises:
            TimeoutError: If the request times out (10-second httpx timeout).
        """
        data: dict[str, str] = {}
        if path is not None:
            data['path'] = path
        try:
            response = self._send_action_server_request(
                'POST',
                f'{self.action_execution_server_url}/list_files',
                json=data,
                timeout=10,
            )
        except httpx.TimeoutException as e:
            raise TimeoutError('List files operation timed out') from e
        assert response.is_closed
        response_json = response.json()
        assert isinstance(response_json, list)
        return response_json

    def _remove_temp_file(self, temp_path: str | None) -> None:
        """Best-effort removal of a host temp file; failures are logged, not raised."""
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except Exception as e:
                self.log(
                    'error',
                    f'Failed to delete temporary zip file {temp_path}: {e}',
                )

    def copy_from(self, path: str) -> Path:
        """Download ``path`` from the sandbox into a temporary ``.zip`` file.

        The ``/download_files`` response is streamed to a host temp file created
        with ``delete=False``; the caller owns the returned file and must delete
        it. If the download fails part-way, the partial file is removed.

        Args:
            path: Sandbox path to download.

        Returns:
            Path of the temporary zip file on the host.

        Raises:
            TimeoutError: If the download times out (30-second httpx timeout).
        """
        params = {'path': path}
        temp_path: str | None = None
        try:
            with self.session.stream(
                'GET',
                f'{self.action_execution_server_url}/download_files',
                params=params,
                timeout=30,
            ) as response:
                with tempfile.NamedTemporaryFile(
                    suffix='.zip', delete=False
                ) as temp_file:
                    # Remember the name so a partial download can be cleaned up.
                    temp_path = temp_file.name
                    for chunk in response.iter_bytes():
                        temp_file.write(chunk)
                    temp_file.flush()
                    return Path(temp_file.name)
        except httpx.TimeoutException as e:
            self._remove_temp_file(temp_path)
            raise TimeoutError('Copy operation timed out') from e
        except Exception:
            # delete=False means nothing else would remove a half-written archive.
            self._remove_temp_file(temp_path)
            raise

    def copy_to(
        self, host_src: str, sandbox_dest: str, recursive: bool = False
    ) -> None:
        """Upload a host file, or a directory tree when ``recursive``, to the sandbox.

        With ``recursive`` the files under ``host_src`` are zipped into a
        temporary archive (entry names relative to the parent directory of
        ``host_src``) and that archive is uploaded; otherwise ``host_src`` itself
        is uploaded. The destination and the ``recursive`` flag are sent as query
        parameters to ``/upload_file``. The temporary archive is always removed.

        Args:
            host_src: Path of the file or directory on the host.
            sandbox_dest: Destination path inside the sandbox.
            recursive: Whether ``host_src`` is uploaded as a zipped directory tree.

        Raises:
            FileNotFoundError: If ``host_src`` does not exist on the host.
        """
        if not os.path.exists(host_src):
            raise FileNotFoundError(f'Source file {host_src} does not exist')

        params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
        temp_zip_path: str | None = None
        try:
            if recursive:
                with tempfile.NamedTemporaryFile(
                    suffix='.zip', delete=False
                ) as temp_zip:
                    # Record the name first so the finally clause can remove the
                    # archive even if zipping fails part-way.
                    temp_zip_path = temp_zip.name
                    # Write through the open handle instead of reopening by name.
                    with ZipFile(temp_zip, 'w') as zipf:
                        archive_root = os.path.dirname(host_src)
                        for root, _, files in os.walk(host_src):
                            for file in files:
                                file_path = os.path.join(root, file)
                                arcname = os.path.relpath(file_path, archive_root)
                                zipf.write(file_path, arcname)
                self.log(
                    'debug',
                    f'Opening temporary zip file for upload: {temp_zip_path}',
                )
                upload_path = temp_zip_path
            else:
                upload_path = host_src

            with open(upload_path, 'rb') as file_to_upload:
                response = self._send_action_server_request(
                    'POST',
                    f'{self.action_execution_server_url}/upload_file',
                    files={'file': file_to_upload},
                    params=params,
                    timeout=300,
                )
            self.log(
                'debug',
                f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
            )
        finally:
            self._remove_temp_file(temp_zip_path)

    def get_vscode_token(self) -> str:
        """Return the VSCode connection token, or '' when none is available.

        '' is returned when VSCode is disabled, the runtime is not initialized,
        or the server reports a null token (that case is not cached). A non-null
        token is cached after the first successful fetch.
        """
        if not (self.vscode_enabled and self.runtime_initialized):
            return ''
        if self._vscode_token is not None:  # cached value
            return self._vscode_token
        response = self._send_action_server_request(
            'GET',
            f'{self.action_execution_server_url}/vscode/connection_token',
            timeout=10,
        )
        response_json = response.json()
        assert isinstance(response_json, dict)
        token = response_json['token']
        if token is None:
            return ''
        self._vscode_token = token
        return token

    def send_action_for_execution(self, action: Action) -> Observation:
        """Execute ``action`` on the action execution server and return its observation.

        LLM-based file edits are delegated to ``llm_based_edit``. Non-runnable,
        awaiting-confirmation, unsupported and user-rejected actions are answered
        locally without contacting the server. Server round-trips are serialized
        by ``action_semaphore``.

        Raises:
            RuntimeError: If a blocking CmdRunAction has no timeout set.
            ValueError: If the action type is not a known action.
            AgentRuntimeTimeoutError: If the server does not answer within the
                action timeout plus a 5-second grace period.
        """
        if (
            isinstance(action, FileEditAction)
            and action.impl_source == FileEditSource.LLM_BASED_EDIT
        ):
            return self.llm_based_edit(action)

        # set timeout to default if not set
        if action.timeout is None:
            if isinstance(action, CmdRunAction) and action.blocking:
                raise RuntimeError('Blocking command with no timeout set')
            # We don't block the command if this is a default timeout action
            action.set_hard_timeout(self.config.sandbox.timeout, blocking=False)

        with self.action_semaphore:
            if not action.runnable:
                if isinstance(action, AgentThinkAction):
                    return AgentThinkObservation('Your thought has been logged.')
                return NullObservation('')

            # Not every action type carries a confirmation_state.
            confirmation_state = getattr(action, 'confirmation_state', None)
            if confirmation_state == ActionConfirmationStatus.AWAITING_CONFIRMATION:
                return NullObservation('')

            action_type = action.action  # type: ignore[attr-defined]
            if action_type not in ACTION_TYPE_TO_CLASS:
                raise ValueError(f'Action {action_type} does not exist.')
            # The action type doubles as the name of the runtime method handling it.
            if not hasattr(self, action_type):
                return ErrorObservation(
                    f'Action {action_type} is not supported in the current runtime.',
                    error_id='AGENT_ERROR$BAD_ACTION',
                )

            if confirmation_state == ActionConfirmationStatus.REJECTED:
                return UserRejectObservation(
                    'Action has been rejected by the user! Waiting for further user input.'
                )

            assert action.timeout is not None
            execution_action_body: dict[str, Any] = {
                'action': event_to_dict(action),
            }
            try:
                response = self._send_action_server_request(
                    'POST',
                    f'{self.action_execution_server_url}/execute_action',
                    json=execution_action_body,
                    # wait a few more seconds to get the timeout error from client side
                    timeout=action.timeout + 5,
                )
            except httpx.TimeoutException as e:
                raise AgentRuntimeTimeoutError(
                    f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
                ) from e
            assert response.is_closed
            obs = observation_from_dict(response.json())
            obs._cause = action.id  # type: ignore[attr-defined]
            return obs

    def run(self, action: CmdRunAction) -> Observation:
        """Execute a shell command action on the server."""
        return self.send_action_for_execution(action)

    def run_ipython(self, action: IPythonRunCellAction) -> Observation:
        """Execute an IPython cell action on the server."""
        return self.send_action_for_execution(action)

    def read(self, action: FileReadAction) -> Observation:
        """Execute a file-read action on the server."""
        return self.send_action_for_execution(action)

    def write(self, action: FileWriteAction) -> Observation:
        """Execute a file-write action on the server."""
        return self.send_action_for_execution(action)

    def edit(self, action: FileEditAction) -> Observation:
        """Execute a file-edit action (LLM-based edits are handled locally)."""
        return self.send_action_for_execution(action)

    def browse(self, action: BrowseURLAction) -> Observation:
        """Execute a browse-URL action on the server."""
        return self.send_action_for_execution(action)

    def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
        """Execute an interactive browsing action on the server."""
        return self.send_action_for_execution(action)

    def get_mcp_config(
        self, extra_stdio_servers: list[MCPStdioServerConfig] | None = None
    ) -> MCPConfig:
        """Build the MCP config for clients, syncing stdio servers with the runtime.

        The configured stdio servers plus ``extra_stdio_servers`` are compared
        with those last pushed to the server. If any are new, the union of both
        lists (sorted by name) is POSTed to ``/update_mcp_server``. Whenever at
        least one stdio server has been pushed successfully, the runtime's
        ``/mcp/sse`` endpoint is appended to the returned SSE servers. On Windows
        an empty config is returned and nothing is sent.

        Args:
            extra_stdio_servers: Stdio servers to register in addition to those
                in ``self.config.mcp``.

        Returns:
            A deep copy of ``self.config.mcp`` (the stored config is never
            mutated), with the runtime SSE server added when applicable.
        """
        import sys

        # Check if we're on Windows - MCP is disabled on Windows
        if sys.platform == 'win32':
            self.log('debug', 'MCP is disabled on Windows, returning empty config')
            return MCPConfig(sse_servers=[], stdio_servers=[])

        # Deep copy: a shallow model_copy() shares the sse_servers list with
        # self.config.mcp, so the append below would mutate the stored config
        # and accumulate a duplicate runtime entry on every call.
        updated_mcp_config = self.config.mcp.model_copy(deep=True)

        current_stdio_servers: list[MCPStdioServerConfig] = list(
            updated_mcp_config.stdio_servers
        )
        if extra_stdio_servers:
            current_stdio_servers.extend(extra_stdio_servers)

        # Check if there are any new servers using the __eq__ operator
        new_servers = [
            server
            for server in current_stdio_servers
            if server not in self._last_updated_mcp_stdio_servers
        ]
        self.log(
            'debug',
            f'adding {len(new_servers)} new stdio servers to MCP config: {new_servers}',
        )

        if new_servers:
            # Union of current and previously pushed servers, so neither list
            # loses entries the other is missing.
            combined_servers = current_stdio_servers.copy()
            for server in self._last_updated_mcp_stdio_servers:
                if server not in combined_servers:
                    combined_servers.append(server)

            stdio_tools = [
                server.model_dump(mode='json') for server in combined_servers
            ]
            stdio_tools.sort(key=lambda x: x.get('name', ''))  # Sort by server name

            self.log(
                'debug',
                f'Updating MCP server with {len(new_servers)} new stdio servers (total: {len(combined_servers)})',
            )
            response = self._send_action_server_request(
                'POST',
                f'{self.action_execution_server_url}/update_mcp_server',
                json=stdio_tools,
                timeout=60,
            )
            if response.status_code != 200:
                # Do not parse the body: an error response need not be JSON.
                self.log('warning', f'Failed to update MCP server: {response.text}')
            else:
                result = response.json()
                if result.get('router_error_log'):
                    self.log(
                        'warning',
                        f'Some MCP servers failed to be added: {result["router_error_log"]}',
                    )
                # Update our cached list only after a successful update.
                self._last_updated_mcp_stdio_servers = combined_servers.copy()
                self.log(
                    'debug',
                    f'Successfully updated MCP stdio servers, now tracking {len(combined_servers)} servers',
                )
                self.log(
                    'info',
                    f'Updated MCP config: {updated_mcp_config.sse_servers}',
                )
        else:
            self.log('debug', 'No new stdio servers to update')

        if self._last_updated_mcp_stdio_servers:
            # Always expose the runtime as an MCP server once any stdio server is registered.
            updated_mcp_config.sse_servers.append(
                MCPSSEServerConfig(
                    url=self.action_execution_server_url.rstrip('/') + '/mcp/sse',
                    api_key=self.session_api_key,
                )
            )
        return updated_mcp_config

    async def call_tool_mcp(self, action: MCPAction) -> Observation:
        """Run an MCP tool call through MCP clients created for this call.

        Returns an ErrorObservation on Windows, where MCP is disabled.
        """
        import sys

        # Check if we're on Windows - MCP is disabled on Windows
        if sys.platform == 'win32':
            self.log('info', 'MCP functionality is disabled on Windows')
            return ErrorObservation('MCP functionality is not available on Windows')

        # Import here to avoid circular imports
        from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
        from openhands.mcp.utils import create_mcp_clients

        updated_mcp_config = self.get_mcp_config()
        self.log(
            'debug',
            f'Creating MCP clients with servers: {updated_mcp_config.sse_servers}',
        )
        mcp_clients = await create_mcp_clients(
            updated_mcp_config.sse_servers, updated_mcp_config.shttp_servers, self.sid
        )
        # No try/finally: per the original note, disconnect() only resets client state.
        return await call_tool_mcp_handler(mcp_clients, action)

    def close(self) -> None:
        """Close the HTTP session; repeated calls are no-ops."""
        # Make sure we don't close the session multiple times
        # (this can happen in evaluation).
        if self._runtime_closed:
            return
        self._runtime_closed = True
        self.session.close()