Spaces:
Build error
Build error
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from unittest.mock import AsyncMock, MagicMock, Mock | |
| from uuid import uuid4 | |
| import pytest | |
| from openhands.controller.agent import Agent | |
| from openhands.controller.agent_controller import AgentController | |
| from openhands.controller.state.state import State | |
| from openhands.core.config import LLMConfig | |
| from openhands.core.config.agent_config import AgentConfig | |
| from openhands.core.schema import AgentState | |
| from openhands.events import EventSource, EventStream | |
| from openhands.events.action import ( | |
| AgentDelegateAction, | |
| AgentFinishAction, | |
| MessageAction, | |
| ) | |
| from openhands.events.action.agent import RecallAction | |
| from openhands.events.event import Event, RecallType | |
| from openhands.events.observation.agent import RecallObservation | |
| from openhands.events.stream import EventStreamSubscriber | |
| from openhands.llm.llm import LLM | |
| from openhands.llm.metrics import Metrics | |
| from openhands.memory.memory import Memory | |
| from openhands.storage.memory import InMemoryFileStore | |
@pytest.fixture
def mock_event_stream():
    """Provide a fresh in-memory ``EventStream`` for a single test.

    Registered as a pytest fixture because the tests in this module request
    ``mock_event_stream`` by parameter name; as a plain function pytest
    would fail to resolve that name.

    Returns:
        EventStream: a stream with a unique ``test-<uuid4>`` session id,
        backed by an initially empty ``InMemoryFileStore``.
    """
    sid = f'test-{uuid4()}'
    file_store = InMemoryFileStore({})
    return EventStream(sid=sid, file_store=file_store)
@pytest.fixture
def mock_parent_agent():
    """Provide a mock parent agent for the delegation tests.

    Registered as a pytest fixture because the tests in this module request
    ``mock_parent_agent`` by parameter name.

    Returns:
        MagicMock: a mock constrained to the ``Agent`` spec, named
        ``'ParentAgent'``, carrying a spec'd ``LLM`` mock with real
        ``Metrics`` / ``LLMConfig`` objects, a real ``AgentConfig``, and a
        ``get_system_message()`` that returns a prepared
        ``SystemMessageAction``.
    """
    # Kept as a function-local import, as in the original: the module-level
    # imports visible in this file do not provide SystemMessageAction.
    from openhands.events.action.message import SystemMessageAction

    agent = MagicMock(spec=Agent)
    agent.name = 'ParentAgent'
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = LLMConfig()
    agent.config = AgentConfig()

    # Add a proper system message mock
    system_message = SystemMessageAction(content='Test system message')
    system_message._source = EventSource.AGENT
    system_message._id = -1  # Set invalid ID to avoid the ID check
    agent.get_system_message.return_value = system_message
    return agent
@pytest.fixture
def mock_child_agent():
    """Provide a mock child (delegate) agent for the delegation tests.

    Registered as a pytest fixture because ``test_delegation_flow`` requests
    ``mock_child_agent`` by parameter name.

    Returns:
        MagicMock: a mock constrained to the ``Agent`` spec, named
        ``'ChildAgent'``, carrying a spec'd ``LLM`` mock with real
        ``Metrics`` / ``LLMConfig`` objects, a real ``AgentConfig``, and a
        ``get_system_message()`` that returns a prepared
        ``SystemMessageAction``.
    """
    # Kept as a function-local import, as in the original: the module-level
    # imports visible in this file do not provide SystemMessageAction.
    from openhands.events.action.message import SystemMessageAction

    agent = MagicMock(spec=Agent)
    agent.name = 'ChildAgent'
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = LLMConfig()
    agent.config = AgentConfig()

    # Add a proper system message mock
    system_message = SystemMessageAction(content='Test system message')
    system_message._source = EventSource.AGENT
    system_message._id = -1  # Set invalid ID to avoid the ID check
    agent.get_system_message.return_value = system_message
    return agent
@pytest.mark.asyncio
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
    """
    Test that when the parent agent delegates to a child, the parent's delegate
    is set, and once the child finishes, the parent is cleaned up properly.

    Args:
        mock_parent_agent: mock agent whose ``step()`` returns the delegate action.
        mock_child_agent: mock agent handed out by the patched ``Agent.get_cls``.
        mock_event_stream: in-memory event stream used by the parent controller.
    """
    # Local import: the module-level imports only bring in the Mock classes.
    from unittest.mock import patch

    def on_event(event: Event):
        # Answer every RecallAction with a RecallObservation, as Memory would.
        if isinstance(event, RecallAction):
            # create a RecallObservation
            microagent_observation = RecallObservation(
                recall_type=RecallType.KNOWLEDGE,
                content='Found info',
            )
            microagent_observation._cause = event.id  # ignore attr-defined warning
            mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)

    # Mock the agent class resolution so that AgentController can instantiate
    # mock_child_agent. patch.object restores the real Agent.get_cls on exit,
    # so the override cannot leak into other tests.
    child_factory = Mock(return_value=lambda llm, config: mock_child_agent)
    with patch.object(Agent, 'get_cls', child_factory):
        # Create parent controller
        parent_state = State(max_iterations=10)
        parent_controller = AgentController(
            agent=mock_parent_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='parent',
            confirmation_mode=False,
            headless_mode=True,
            initial_state=parent_state,
        )
        try:
            # Setup Memory to catch RecallActions
            mock_memory = MagicMock(spec=Memory)
            mock_memory.event_stream = mock_event_stream
            mock_memory.on_event = on_event
            mock_event_stream.subscribe(
                EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
            )

            # Setup a delegate action from the parent
            delegate_action = AgentDelegateAction(
                agent='ChildAgent', inputs={'test': True}
            )
            mock_parent_agent.step.return_value = delegate_action

            # Simulate a user message event to cause parent.step() to run
            message_action = MessageAction(content='please delegate now')
            message_action._source = EventSource.USER
            await parent_controller._on_event(message_action)

            # Give time for the async step() to execute
            await asyncio.sleep(1)

            # Verify that a RecallObservation was added to the event stream
            events = list(mock_event_stream.get_events())

            # SystemMessageAction, RecallAction, AgentChangeState, AgentDelegateAction, SystemMessageAction (for child)
            # (the RecallObservation emitted by on_event above is in the stream as well)
            assert mock_event_stream.get_latest_event_id() == 5

            # a RecallObservation and an AgentDelegateAction should be in the list
            assert any(isinstance(event, RecallObservation) for event in events)
            assert any(isinstance(event, AgentDelegateAction) for event in events)

            # Verify that a delegate agent controller is created
            assert parent_controller.delegate is not None, (
                "Parent's delegate controller was not set."
            )

            # The parent's iteration should have incremented
            assert parent_controller.state.iteration == 1, (
                'Parent iteration should be incremented after step.'
            )

            # Now simulate that the child increments local iteration and finishes its subtask
            delegate_controller = parent_controller.delegate
            delegate_controller.state.iteration = 5  # child had some steps
            delegate_controller.state.outputs = {'delegate_result': 'done'}

            # The child is done, so we simulate it finishing:
            child_finish_action = AgentFinishAction()
            await delegate_controller._on_event(child_finish_action)
            await asyncio.sleep(0.5)

            # Now the parent's delegate is None
            assert parent_controller.delegate is None, (
                'Parent delegate should be None after child finishes.'
            )

            # Parent's global iteration is updated from the child
            assert parent_controller.state.iteration == 6, (
                "Parent iteration should be the child's iteration + 1 after child is done."
            )
        finally:
            # Cleanup, even when an assertion above fails
            await parent_controller.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        # Exercises the "delegate stays attached" branch asserted below.
        AgentState.RUNNING,
        # Exercise the "delegate is ended" branch asserted below.
        # NOTE(review): assumes the controller treats these three as the
        # states that end a delegate - confirm against AgentController.
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_parent_agent, mock_event_stream, delegate_state
):
    """Ensure that delegate is closed or remains open based on the delegate's state.

    Args:
        mock_parent_agent: mock agent driving the controller under test.
        mock_event_stream: in-memory event stream used by the controller.
        delegate_state: the ``AgentState`` the mocked delegate reports.
    """
    controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    def call_on_event_with_new_loop():
        """
        In this thread, create and set a fresh event loop, so that the run_until_complete()
        calls inside controller.on_event(...) find a valid loop.
        """
        loop_in_thread = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop_in_thread)
            msg_action = MessageAction(content='Test message')
            msg_action._source = EventSource.USER
            controller.on_event(msg_action)
        finally:
            loop_in_thread.close()

    try:
        # Attach a fully mocked delegate reporting the parametrized state.
        mock_delegate = AsyncMock()
        controller.delegate = mock_delegate
        mock_delegate.state.iteration = 5
        mock_delegate.state.outputs = {'result': 'test'}
        mock_delegate.agent.name = 'TestDelegate'
        mock_delegate.get_agent_state = Mock(return_value=delegate_state)
        mock_delegate._step = AsyncMock()
        mock_delegate.close = AsyncMock()

        # Run the synchronous on_event() in a worker thread so it does not
        # block the test's own running loop.
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            future = loop.run_in_executor(executor, call_on_event_with_new_loop)
            await future

        if delegate_state == AgentState.RUNNING:
            assert controller.delegate is not None
            assert controller.state.iteration == 0
            mock_delegate.close.assert_not_called()
        else:
            assert controller.delegate is None
            assert controller.state.iteration == 5
            # The close method is called once in end_delegate
            assert mock_delegate.close.call_count == 1
    finally:
        # Cleanup, even when an assertion above fails
        await controller.close()