Spaces:
Build error
Build error
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from openhands.controller.agent import Agent | |
| from openhands.controller.agent_controller import AgentController | |
| from openhands.controller.state.state import State | |
| from openhands.core.config import LLMConfig, OpenHandsConfig | |
| from openhands.core.config.agent_config import AgentConfig | |
| from openhands.events import EventStream, EventStreamSubscriber | |
| from openhands.llm import LLM | |
| from openhands.llm.metrics import Metrics | |
| from openhands.memory.memory import Memory | |
| from openhands.runtime.impl.action_execution.action_execution_client import ( | |
| ActionExecutionClient, | |
| ) | |
| from openhands.server.session.agent_session import AgentSession | |
| from openhands.storage.memory import InMemoryFileStore | |
@pytest.fixture
def mock_agent():
    """Create a properly configured mock Agent with all required nested attributes.

    Intended to be passed as ``agent=`` to ``AgentSession.start()``. The mock is
    spec'd against ``Agent`` and carries:

    * ``llm``: a ``LLM``-spec'd mock whose ``metrics`` is a ``Metrics``-spec'd
      mock and whose ``config`` is an ``LLMConfig``-spec'd mock with a test
      model name, base URL and ``max_message_chars``.
    * ``config``: an ``AgentConfig``-spec'd mock with no disabled microagents
      and MCP enabled.
    * ``name`` (``'test-agent'``), an empty ``sandbox_plugins`` list and a plain
      ``prompt_manager`` mock.

    Returns:
        The configured ``MagicMock`` standing in for an ``Agent``.
    """
    # Create the base mocks. ``spec=`` makes reads of attributes that do not
    # exist on the real class raise AttributeError instead of silently
    # returning a new mock.
    agent = MagicMock(spec=Agent)
    llm = MagicMock(spec=LLM)
    metrics = MagicMock(spec=Metrics)
    llm_config = MagicMock(spec=LLMConfig)
    agent_config = MagicMock(spec=AgentConfig)

    # Configure the LLM config
    llm_config.model = 'test-model'
    llm_config.base_url = 'http://test'
    llm_config.max_message_chars = 1000

    # Configure the agent config
    agent_config.disabled_microagents = []
    agent_config.enable_mcp = True

    # Wire the mocks together: agent.llm.metrics / agent.llm.config / agent.config
    llm.metrics = metrics
    llm.config = llm_config
    agent.llm = llm
    agent.name = 'test-agent'
    agent.sandbox_plugins = []
    agent.config = agent_config
    agent.prompt_manager = MagicMock()
    return agent
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
    """Test that AgentSession.start() works correctly when there's no state to restore.

    ``State.restore_from_session`` is patched to raise, so the controller is
    expected to receive ``state=None`` exactly once and to end up with a state
    whose ``max_iterations`` is the value passed to ``start()``, with
    ``start_id == 0`` and ``end_id == -1``. It also checks that both the
    controller and the memory subscribed to the event stream.
    """
    # Setup
    file_store = InMemoryFileStore({})
    session = AgentSession(
        sid='test-session',
        file_store=file_store,
    )

    # Create a mock runtime and set it up
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Replace _create_runtime so start() just attaches the mock runtime and
    # reports success.
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with no events
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 0

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Spy on set_initial_state: count calls and capture the ``state`` argument,
    # then delegate to the real implementation. The class-level values are
    # defaults only; ``+=`` / assignment on ``self`` creates per-instance
    # attributes, so each controller instance starts counting from 0.
    class SpyAgentController(AgentController):
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
    with (
        patch(
            'openhands.server.session.agent_session.AgentController', SpyAgentController
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify EventStream.subscribe was called with correct parameters.
        # assert_any_call is used because subscribe is called more than once
        # and the call order is not what this test is about.
        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.AGENT_CONTROLLER,
            session.controller.on_event,
            session.controller.id,
        )

        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.MEMORY,
            session.memory.on_event,
            session.controller.id,
        )

        # Verify set_initial_state was called once with None as state
        assert session.controller.set_initial_state_call_count == 1
        assert session.controller.test_initial_state is None
        assert session.controller.state.max_iterations == 10
        assert session.controller.agent.name == 'test-agent'
        assert session.controller.state.start_id == 0
        assert session.controller.state.end_id == -1
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
    """Test that AgentSession.start() works correctly when there's a state to restore.

    ``State.restore_from_session`` is patched to return a mock State with
    ``start_id == -1`` and ``max_iterations == 5``. The test expects:

    * the controller receives that exact object once via ``set_initial_state``
      and keeps it as its state;
    * the restored ``max_iterations`` (5) wins over the value passed to
      ``start()`` (10);
    * ``start_id`` ends up as 0, normalized from -1.
    """
    # Setup
    file_store = InMemoryFileStore({})
    session = AgentSession(
        sid='test-session',
        file_store=file_store,
    )

    # Create a mock runtime and set it up
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Replace _create_runtime so start() just attaches the mock runtime and
    # reports success.
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream: get_events yields nothing, but the latest
    # event id is 5, i.e. the stream reports that events already exist.
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 5  # Indicate some events exist

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a mock restored state. start_id = -1 lets us check that the
    # controller normalizes it (asserted as 0 below).
    mock_restored_state = MagicMock(spec=State)
    mock_restored_state.start_id = -1
    mock_restored_state.end_id = -1
    mock_restored_state.max_iterations = 5

    # Spy on set_initial_state: count calls and capture the ``state`` argument,
    # then delegate to the real implementation. The class-level values are
    # defaults only; ``+=`` / assignment on ``self`` creates per-instance
    # attributes, so each controller instance starts counting from 0.
    class SpyAgentController(AgentController):
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # create a mock Memory
    mock_memory = MagicMock(spec=Memory)

    # Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
    with (
        patch(
            'openhands.server.session.agent_session.AgentController', SpyAgentController
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            return_value=mock_restored_state,
        ),
        patch('openhands.server.session.agent_session.Memory', mock_memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify set_initial_state was called once with the restored state
        assert session.controller.set_initial_state_call_count == 1

        # Verify EventStream.subscribe was called with correct parameters.
        # assert_any_call (not assert_called_with, which only checks the LAST
        # call) keeps this independent of subscription order, matching the
        # no-state test.
        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.AGENT_CONTROLLER,
            session.controller.on_event,
            session.controller.id,
        )

        assert session.controller.test_initial_state is mock_restored_state
        assert session.controller.state is mock_restored_state
        assert session.controller.state.max_iterations == 5
        assert session.controller.state.start_id == 0
        assert session.controller.state.end_id == -1