Spaces:
Build error
Build error
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from openhands.controller.agent import Agent | |
| from openhands.controller.agent_controller import AgentController | |
| from openhands.controller.state.state import State | |
| from openhands.core.config import OpenHandsConfig | |
| from openhands.events import EventSource | |
| from openhands.events.action import CmdRunAction, MessageAction, RecallAction | |
| from openhands.events.action.message import SystemMessageAction | |
| from openhands.events.event import RecallType | |
| from openhands.events.observation import ( | |
| CmdOutputObservation, | |
| Observation, | |
| RecallObservation, | |
| ) | |
| from openhands.events.stream import EventStream | |
| from openhands.llm.llm import LLM | |
| from openhands.llm.metrics import Metrics | |
| from openhands.storage.memory import InMemoryFileStore | |
def create_events(event_data):
    """Build a list of events with sequential IDs from declarative specs.

    Each item of ``event_data`` is a dict with:

    * ``'type'`` (required): the event class to instantiate.
    * ``'source'`` (optional): stored on ``event._source``; defaults to
      ``EventSource.AGENT``.
    * ``'cause_id'`` (optional): stored on ``event._cause``.
    * type-specific fields forwarded to the constructor:
      ``'query'``/``'recall_type'`` for ``RecallAction``,
      ``'content'``/``'recall_type'`` for ``RecallObservation``,
      ``'command'`` for ``CmdRunAction``,
      ``'content'``/``'command'``/``'command_id'``/``'metadata'`` for
      ``CmdOutputObservation``, and ``'content'`` for every other type.
      Missing text fields default to ``''`` and a missing ``'recall_type'``
      defaults to ``RecallType.KNOWLEDGE``.

    For ``CmdOutputObservation`` the ``command_id`` constructor argument is
    taken from ``'command_id'`` when given, otherwise from ``'cause_id'``
    when given. It is resolved *before* construction so each event is built
    exactly once.

    Events receive ``_id`` values 1, 2, 3, ... in input order.

    Returns:
        The constructed events, in the same order as ``event_data``.
    """
    # Imported locally so the helper carries its own type references.
    from openhands.events.action import CmdRunAction, RecallAction
    from openhands.events.observation import CmdOutputObservation, RecallObservation

    events = []
    for event_id, data in enumerate(event_data, start=1):
        event_type = data['type']
        source = data.get('source', EventSource.AGENT)

        # Collect constructor arguments according to the event type.
        if event_type is RecallAction:
            kwargs = {
                'query': data.get('query', ''),
                'recall_type': data.get('recall_type', RecallType.KNOWLEDGE),
            }
        elif event_type is RecallObservation:
            kwargs = {
                'content': data.get('content', ''),
                'recall_type': data.get('recall_type', RecallType.KNOWLEDGE),
            }
        elif event_type is CmdRunAction:
            kwargs = {'command': data.get('command', '')}
        elif event_type is CmdOutputObservation:
            kwargs = {
                'content': data.get('content', ''),
                'command': data.get('command', ''),
            }
            # An explicit command_id wins; otherwise fall back to cause_id so
            # the constructor can apply its legacy command_id handling.
            if 'command_id' in data:
                kwargs['command_id'] = data['command_id']
            elif 'cause_id' in data:
                kwargs['command_id'] = data['cause_id']
            if 'metadata' in data:
                kwargs['metadata'] = data['metadata']
        else:
            # MessageAction, SystemMessageAction, etc. only need content.
            kwargs = {'content': data.get('content', '')}

        event = event_type(**kwargs)

        # Internal attributes are not constructor arguments, so they are
        # assigned after instantiation.
        event._id = event_id
        event._source = source
        if 'cause_id' in data:
            event._cause = data['cause_id']

        events.append(event)
    return events
@pytest.fixture
def controller_fixture():
    """Provide an ``AgentController`` wired to mocks for windowing tests.

    The tests in this module request ``controller_fixture`` by name as a
    pytest fixture, so the function must be registered with
    ``@pytest.fixture``; without it pytest cannot resolve the argument.

    Returns:
        A ``(controller, mock_first_user_message)`` tuple. ``controller`` has
        a fresh ``State`` with session id ``'test_sid'``, and its
        ``_first_user_message`` attribute is replaced by a mock that returns
        ``mock_first_user_message``. Tests set ``mock_first_user_message.id``
        to mark which event is the first user message.
    """
    mock_agent = MagicMock(spec=Agent)
    mock_agent.llm = MagicMock(spec=LLM)
    mock_agent.llm.metrics = Metrics()
    mock_agent.llm.config = OpenHandsConfig().get_llm_config()
    mock_agent.config = OpenHandsConfig().get_agent_config('CodeActAgent')

    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.sid = 'test_sid'
    mock_event_stream.file_store = InMemoryFileStore({})
    # Ensure get_latest_event_id returns an integer rather than a MagicMock.
    mock_event_stream.get_latest_event_id.return_value = -1

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_sid',
    )
    controller.state = State(session_id='test_sid')

    # Mock _first_user_message directly on the instance so each test controls
    # which event is treated as the first user message.
    mock_first_user_message = MagicMock(spec=MessageAction)
    controller._first_user_message = MagicMock(return_value=mock_first_user_message)
    return controller, mock_first_user_message
| # ============================================= | |
| # Test Cases for _apply_conversation_window | |
| # ============================================= | |
def test_basic_truncation(controller_fixture):
    """A 10-event history is reduced to the essential head plus a clean tail.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: system(1), user(2), recall action(3), recall obs(4)
    * non-essential count: 10 - 4 = 6, so keep max(1, 6 // 2) = 3 recent
    * recent slice: history[7:] -> obs(8), cmd(9), obs(10)
    * the leading observation (8) is dangling and is dropped

    Expected IDs: [1, 2, 3, 4, 9, 10].
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # event 2 is the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        {'type': CmdRunAction, 'command': 'ls'},  # 5
        {'type': CmdOutputObservation, 'content': 'file1', 'command': 'ls', 'cause_id': 5},  # 6
        {'type': CmdRunAction, 'command': 'pwd'},  # 7
        {'type': CmdOutputObservation, 'content': '/dir', 'command': 'pwd', 'cause_id': 7},  # 8
        {'type': CmdRunAction, 'command': 'cat file1'},  # 9
        {
            'type': CmdOutputObservation,
            'content': 'content',
            'command': 'cat file1',
            'cause_id': 9,
        },  # 10
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 6
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 4, 9, 10]

    # The first event after the four essentials must be an action (cmd 9),
    # never a dangling observation.
    first_recent = windowed[4]
    assert not isinstance(first_recent, Observation)
def test_no_system_message(controller_fixture):
    """Windowing still works when the history has no system message.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: user(1), recall action(2), recall obs(3)
    * non-essential count: 9 - 3 = 6, so keep max(1, 6 // 2) = 3 recent
    * recent slice: history[6:] -> obs(7), cmd(8), obs(9)
    * the leading observation (7) is dangling and is dropped

    Expected IDs: [1, 2, 3, 8, 9].
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 1  # with no system message, the user message is event 1

    event_specs = [
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 1
        {'type': RecallAction, 'query': 'User Task 1'},  # 2
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 2},  # 3
        {'type': CmdRunAction, 'command': 'ls'},  # 4
        {'type': CmdOutputObservation, 'content': 'file1', 'command': 'ls', 'cause_id': 4},  # 5
        {'type': CmdRunAction, 'command': 'pwd'},  # 6
        {'type': CmdOutputObservation, 'content': '/dir', 'command': 'pwd', 'cause_id': 6},  # 7
        {'type': CmdRunAction, 'command': 'cat file1'},  # 8
        {
            'type': CmdOutputObservation,
            'content': 'content',
            'command': 'cat file1',
            'cause_id': 8,
        },  # 9
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 5
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 8, 9]
def test_no_recall_observation(controller_fixture):
    """Windowing when the RecallAction has no matching RecallObservation.

    The assertion pins that the RecallAction (3) is retained even though its
    observation is absent. Whether or not it is counted as essential, the
    number of recent events to keep is the same: (9 - 3) // 2 == 3 and
    (9 - 2) // 2 == 3.

    Walk-through of the windowing rule as this test assumes it:

    * recent slice: history[6:] -> obs(7), cmd(8), obs(9)
    * the leading observation (7) is dangling and is dropped

    Expected IDs: [1, 2, 3, 8, 9].
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # event 2 is the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3 (no observation follows)
        {'type': CmdRunAction, 'command': 'ls'},  # 4
        {'type': CmdOutputObservation, 'content': 'file1', 'command': 'ls', 'cause_id': 4},  # 5
        {'type': CmdRunAction, 'command': 'pwd'},  # 6
        {'type': CmdOutputObservation, 'content': '/dir', 'command': 'pwd', 'cause_id': 6},  # 7
        {'type': CmdRunAction, 'command': 'cat file1'},  # 8
        {
            'type': CmdOutputObservation,
            'content': 'content',
            'command': 'cat file1',
            'cause_id': 8,
        },  # 9
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 5
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 8, 9]
def test_short_history_no_truncation(controller_fixture):
    """A short history collapses to just the essential events.

    Despite the test's name, the asserted result does drop events 5 and 6.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: system(1), user(2), recall action(3), recall obs(4)
    * non-essential count: 6 - 4 = 2, so keep max(1, 2 // 2) = 1 recent
    * recent slice: history[5:] -> obs(6), which is dangling and is dropped

    Expected IDs: [1, 2, 3, 4].
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # event 2 is the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        {'type': CmdRunAction, 'command': 'ls'},  # 5
        {'type': CmdOutputObservation, 'content': 'file1', 'command': 'ls', 'cause_id': 5},  # 6
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 4
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 4]
def test_only_essential_events(controller_fixture):
    """A history made only of essential events is returned unchanged.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: system(1), user(2), recall action(3), recall obs(4)
    * non-essential count: 4 - 4 = 0, so keep max(1, 0 // 2) = 1 recent
    * recent slice: history[3:] -> recall obs(4), a leading observation that
      is dropped from the slice (it is already among the essentials)

    Expected IDs: [1, 2, 3, 4].
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # event 2 is the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 4
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 4]
def test_dangling_observations_at_cut_point(controller_fixture):
    """Observations without a preceding action never lead the recent slice.

    Events 5 and 6 are observations with no cause; the asserted result
    contains neither of them, nor the observation (8) that lands on the cut.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: system(1), user(2), recall action(3), recall obs(4)
    * non-essential count: 10 - 4 = 6, so keep max(1, 6 // 2) = 3 recent
    * recent slice: history[7:] -> obs(8), cmd(9), obs(10)
    * the leading observation (8) is dangling and is dropped

    Expected IDs: [1, 2, 3, 4, 9, 10].
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # event 2 is the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        {'type': CmdOutputObservation, 'content': 'dangle1', 'command': 'cmd_unknown'},  # 5 (dangling)
        {'type': CmdOutputObservation, 'content': 'dangle2', 'command': 'cmd_unknown'},  # 6 (dangling)
        {'type': CmdRunAction, 'command': 'cmd1'},  # 7
        {'type': CmdOutputObservation, 'content': 'obs1', 'command': 'cmd1', 'cause_id': 7},  # 8
        {'type': CmdRunAction, 'command': 'cmd2'},  # 9
        {'type': CmdOutputObservation, 'content': 'obs2', 'command': 'cmd2', 'cause_id': 9},  # 10
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 6
    kept_ids = [event.id for event in windowed]
    # Dangling observations 5 and 6 are absent from the expected IDs.
    assert kept_ids == [1, 2, 3, 4, 9, 10]
def test_only_dangling_observations_in_recent_slice(controller_fixture):
    """When the whole recent slice is dangling, only essentials remain and a warning is logged.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: system(1), user(2), recall action(3), recall obs(4)
    * non-essential count: 6 - 4 = 2, so keep max(1, 2 // 2) = 1 recent
    * recent slice: history[5:] -> dangle2(6), which is dropped, leaving an
      empty recent slice

    Expected IDs: [1, 2, 3, 4], with exactly one ``logger.warning`` call
    carrying ``extra={'session_id': 'test_sid', ...}``.
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # event 2 is the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        {'type': CmdOutputObservation, 'content': 'dangle1', 'command': 'cmd_unknown'},  # 5 (dangling)
        {'type': CmdOutputObservation, 'content': 'dangle2', 'command': 'cmd_unknown'},  # 6 (dangling)
    ]
    controller.state.history = create_events(event_specs)

    warning_target = 'openhands.controller.agent_controller.logger.warning'
    with patch(warning_target) as warning_mock:
        windowed = controller._apply_conversation_window()

    assert len(windowed) == 4
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 4]

    # The warning must fire exactly once.
    assert warning_mock.call_count == 1

    # Only the essential parts of the call are checked, so extra keyword
    # arguments such as stacklevel do not break the test.
    positional, keyword = warning_mock.call_args
    expected_message_substring = 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.'
    assert expected_message_substring in positional[0]
    assert 'extra' in keyword
    assert keyword['extra'].get('session_id') == 'test_sid'
def test_empty_history(controller_fixture):
    """An empty history produces an empty window."""
    controller, _first_user_msg = controller_fixture
    controller.state.history = []

    windowed = controller._apply_conversation_window()

    assert windowed == []
def test_multiple_user_messages(controller_fixture):
    """Only the first user message is treated as essential.

    Walk-through of the windowing rule as this test assumes it:

    * essentials: system(1), user1(2), recall action1(3), recall obs1(4)
    * non-essential count: 11 - 4 = 7, so keep max(1, 7 // 2) = 3 recent
    * recent slice: history[8:] -> recall obs2(9), cmd2(10), obs2(11)
    * the leading observation (9) is dangling and is dropped

    Expected IDs: [1, 2, 3, 4, 10, 11]; the second user message (7) is gone.
    """
    controller, first_user_msg = controller_fixture
    first_user_msg.id = 2  # explicitly mark event 2 as the first user message

    event_specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {'type': MessageAction, 'content': 'User Task 1', 'source': EventSource.USER},  # 2 (first)
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result 1', 'cause_id': 3},  # 4
        {'type': CmdRunAction, 'command': 'cmd1'},  # 5
        {'type': CmdOutputObservation, 'content': 'obs1', 'command': 'cmd1', 'cause_id': 5},  # 6
        {'type': MessageAction, 'content': 'User Task 2', 'source': EventSource.USER},  # 7 (second)
        {'type': RecallAction, 'query': 'User Task 2'},  # 8
        {'type': RecallObservation, 'content': 'Recall result 2', 'cause_id': 8},  # 9
        {'type': CmdRunAction, 'command': 'cmd2'},  # 10
        {'type': CmdOutputObservation, 'content': 'obs2', 'command': 'cmd2', 'cause_id': 10},  # 11
    ]
    controller.state.history = create_events(event_specs)

    windowed = controller._apply_conversation_window()

    assert len(windowed) == 6
    kept_ids = [event.id for event in windowed]
    assert kept_ids == [1, 2, 3, 4, 10, 11]

    # The second user message (ID 7) must not survive; the first (ID 2) must.
    assert 7 not in kept_ids
    assert 2 in kept_ids