Spaces:
Build error
Build error
| from __future__ import annotations | |
| import json | |
| from typing import Any | |
| from pydantic import BaseModel, Field | |
| from openhands.core.config.condenser_config import ( | |
| StructuredSummaryCondenserConfig, | |
| ) | |
| from openhands.core.logger import openhands_logger as logger | |
| from openhands.core.message import Message, TextContent | |
| from openhands.events.action.agent import CondensationAction | |
| from openhands.events.observation.agent import AgentCondensationObservation | |
| from openhands.events.serialization.event import truncate_content | |
| from openhands.llm import LLM | |
| from openhands.memory.condenser.condenser import ( | |
| Condensation, | |
| RollingCondenser, | |
| View, | |
| ) | |
class StateSummary(BaseModel):
    """A structured representation summarizing the state of the agent and the task.

    Every field is a free-form string that defaults to ``''``. The field
    descriptions double as the parameter descriptions of the function-calling
    tool produced by :meth:`tool_description`.
    """

    # Required core fields
    user_context: str = Field(
        default='',
        description='Essential user requirements, goals, and clarifications in concise form.',
    )
    completed_tasks: str = Field(
        default='', description='List of tasks completed so far with brief results.'
    )
    pending_tasks: str = Field(
        default='', description='List of tasks that still need to be done.'
    )
    current_state: str = Field(
        default='',
        description='Current variables, data structures, or other relevant state information.',
    )

    # Code state fields
    files_modified: str = Field(
        default='', description='List of files that have been created or modified.'
    )
    function_changes: str = Field(
        default='', description='List of functions that have been created or modified.'
    )
    data_structures: str = Field(
        default='', description='List of key data structures in use or modified.'
    )

    # Test status fields
    tests_written: str = Field(
        default='',
        description='Whether tests have been written for the changes. True, false, or unknown.',
    )
    tests_passing: str = Field(
        default='',
        description='Whether all tests are currently passing. True, false, or unknown.',
    )
    failing_tests: str = Field(
        default='', description='List of names or descriptions of any failing tests.'
    )
    error_messages: str = Field(
        default='', description='List of key error messages encountered.'
    )

    # Version control fields
    branch_created: str = Field(
        default='',
        description='Whether a branch has been created for this work. True, false, or unknown.',
    )
    branch_name: str = Field(
        default='', description='Name of the current working branch if known.'
    )
    commits_made: str = Field(
        default='',
        description='Whether any commits have been made. True, false, or unknown.',
    )
    pr_created: str = Field(
        default='',
        description='Whether a pull request has been created. True, false, or unknown.',
    )
    pr_status: str = Field(
        default='',
        description="Status of any pull request: 'draft', 'open', 'merged', 'closed', or 'unknown'.",
    )

    # Other fields
    dependencies: str = Field(
        default='',
        description='List of dependencies or imports that have been added or modified.',
    )
    other_relevant_context: str = Field(
        default='',
        description="Any other important information that doesn't fit into the categories above.",
    )

    @classmethod
    def tool_description(cls) -> dict[str, Any]:
        """Description of a tool whose arguments are the fields of this class.

        Can be given to an LLM to force structured generation.

        This must be a classmethod: it is invoked on the class itself
        (``StateSummary.tool_description()``), not on an instance.

        Returns:
            A function-tool schema dict named ``create_state_summary`` with one
            string parameter per model field, of which ``user_context``,
            ``completed_tasks`` and ``pending_tasks`` are marked required.
        """
        # One string-typed parameter per model field, reusing the field's
        # description (empty string when a field declares none).
        properties: dict[str, dict[str, str]] = {
            field_name: {'type': 'string', 'description': field.description or ''}
            for field_name, field in cls.model_fields.items()
        }

        return {
            'type': 'function',
            'function': {
                'name': 'create_state_summary',
                'description': 'Creates a comprehensive summary of the current state of the interaction to preserve context when history grows too large. You must include non-empty values for user_context, completed_tasks, and pending_tasks.',
                'parameters': {
                    'type': 'object',
                    'properties': properties,
                    'required': ['user_context', 'completed_tasks', 'pending_tasks'],
                },
            },
        }

    def __str__(self) -> str:
        """Render the summary as Markdown.

        The output is a title, five section headings, and one bold-labelled
        line per field, all separated by blank lines. Empty fields are still
        emitted (with an empty value) so the layout is stable.
        """
        sections = [
            '# State Summary',
            '## Core Information',
            f'**User Context**: {self.user_context}',
            f'**Completed Tasks**: {self.completed_tasks}',
            f'**Pending Tasks**: {self.pending_tasks}',
            f'**Current State**: {self.current_state}',
            '## Code Changes',
            f'**Files Modified**: {self.files_modified}',
            f'**Function Changes**: {self.function_changes}',
            f'**Data Structures**: {self.data_structures}',
            f'**Dependencies**: {self.dependencies}',
            '## Testing Status',
            f'**Tests Written**: {self.tests_written}',
            f'**Tests Passing**: {self.tests_passing}',
            f'**Failing Tests**: {self.failing_tests}',
            f'**Error Messages**: {self.error_messages}',
            '## Version Control',
            f'**Branch Created**: {self.branch_created}',
            f'**Branch Name**: {self.branch_name}',
            f'**Commits Made**: {self.commits_made}',
            f'**PR Created**: {self.pr_created}',
            f'**PR Status**: {self.pr_status}',
            '## Additional Context',
            f'**Other Relevant Context**: {self.other_relevant_context}',
        ]

        # Join all sections with double newlines
        return '\n\n'.join(sections)
class StructuredSummaryCondenser(RollingCondenser):
    """A condenser that summarizes forgotten events.

    Maintains a condensed history and forgets old events when it grows too large. Uses structured generation via function-calling to produce summaries that replace forgotten events.
    """

    def __init__(
        self,
        llm: LLM,
        max_size: int = 100,
        keep_first: int = 1,
        max_event_length: int = 10_000,
    ):
        """Validate and store the condenser settings.

        Args:
            llm: LLM used to produce the summary; must have function calling active.
            max_size: View length above which condensation is triggered.
            keep_first: Number of leading events that are always kept.
            max_event_length: Maximum number of characters of any single
                event (or previous summary) included in the prompt.

        Raises:
            ValueError: If ``max_size`` is not positive, ``keep_first`` is
                negative or not less than half of ``max_size``, or the LLM
                does not have function calling active.
        """
        # Validate each value on its own before comparing them to each other,
        # so the error message names the setting that is actually wrong.
        if max_size < 1:
            raise ValueError(f'max_size ({max_size}) cannot be non-positive')
        if keep_first < 0:
            raise ValueError(f'keep_first ({keep_first}) cannot be negative')
        if keep_first >= max_size // 2:
            raise ValueError(
                f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
            )
        if not llm.is_function_calling_active():
            raise ValueError(
                'LLM must support function calling to use StructuredSummaryCondenser'
            )

        self.max_size = max_size
        self.keep_first = keep_first
        self.max_event_length = max_event_length
        self.llm = llm

        super().__init__()

    def _truncate(self, content: str) -> str:
        """Truncate the content to fit within the specified maximum event length."""
        return truncate_content(content, max_chars=self.max_event_length)

    def _build_prompt(self, summary_event: Any, forgotten_events: list[Any]) -> str:
        """Assemble the summarization prompt from the previous summary and the forgotten events."""
        instructions = """You are maintaining a context-aware state summary for an interactive software agent. This summary is critical because it:
1. Preserves essential context when conversation history grows too large
2. Prevents lost work when the session length exceeds token limits
3. Helps maintain continuity across multiple interactions
You will be given:
- A list of events (actions taken by the agent)
- The most recent previous summary (if one exists)
Capture all relevant information, especially:
- User requirements that were explicitly stated
- Work that has been completed
- Tasks that remain pending
- Current state of code, variables, and data structures
- The status of any version control operations"""

        # Add the previous summary if it exists. We'll always have a summary
        # event, but the types aren't precise enough to guarantee that it has a
        # non-empty message attribute.
        previous_summary = self._truncate(summary_event.message or '')

        parts = [
            instructions,
            '\n\n',
            f'<PREVIOUS SUMMARY>\n{previous_summary}\n</PREVIOUS SUMMARY>\n',
            '\n\n',
        ]

        # Add all events that are being forgotten. We use the string
        # representation defined by the event, and truncate it if necessary.
        for forgotten_event in forgotten_events:
            event_content = self._truncate(str(forgotten_event))
            parts.append(
                f'<EVENT id={forgotten_event.id}>\n{event_content}\n</EVENT>\n'
            )

        return ''.join(parts)

    @staticmethod
    def _parse_summary(response: Any) -> StateSummary:
        """Extract a StateSummary from the tool call in ``response``, or an empty one on failure."""
        try:
            # Extract the message containing tool calls
            message = response.choices[0].message

            tool_calls = getattr(message, 'tool_calls', None)
            if not tool_calls:
                raise ValueError('No tool calls found in response')

            # Find the create_state_summary tool call
            summary_tool_call = next(
                (
                    tool_call
                    for tool_call in tool_calls
                    if tool_call.function.name == 'create_state_summary'
                ),
                None,
            )
            if not summary_tool_call:
                raise ValueError('create_state_summary tool call not found')

            # Parse the arguments and build the StateSummary object
            args_dict = json.loads(summary_tool_call.function.arguments)
            return StateSummary.model_validate(args_dict)

        # IndexError covers an empty ``choices`` list and TypeError covers
        # ``arguments`` being None; both are malformed responses and get the
        # same best-effort fallback as the other parse failures.
        except (
            ValueError,
            AttributeError,
            KeyError,
            IndexError,
            TypeError,
            json.JSONDecodeError,
        ) as e:
            logger.warning(
                f'Failed to parse summary tool call: {e}. Using empty summary.'
            )
            return StateSummary()

    def get_condensation(self, view: View) -> Condensation:
        """Summarize the middle of ``view`` and return the resulting condensation.

        The first ``keep_first`` events and a tail of recent events are kept;
        everything in between (except a previous summary observation) is
        sent to the LLM, which must answer with a ``create_state_summary``
        tool call. If that response cannot be parsed, an empty summary is used.

        Raises:
            ValueError: If there are no events to forget in ``view``.
        """
        head = view[: self.keep_first]
        target_size = self.max_size // 2
        # Number of events to keep from the tail -- target size, minus however many
        # prefix events from the head, minus one for the summarization event
        events_from_tail = target_size - len(head) - 1

        # The event right after the head is the previous summary, if any.
        candidate = view[self.keep_first]
        if isinstance(candidate, AgentCondensationObservation):
            summary_event = candidate
        else:
            summary_event = AgentCondensationObservation('No events summarized')

        # Identify events to be forgotten (those not in head or tail). The end
        # index is computed explicitly: with a negative-offset slice, a tail of
        # zero events would become ``-0 == 0`` and select nothing at all.
        forgotten_end = max(len(view) - events_from_tail, self.keep_first)
        forgotten_events = [
            event
            for event in view[self.keep_first : forgotten_end]
            if not isinstance(event, AgentCondensationObservation)
        ]

        # Fail before spending an LLM call: without any forgotten events there
        # is no id range to report in the CondensationAction below.
        if not forgotten_events:
            raise ValueError(
                f'No events to condense: view of length {len(view)} has nothing '
                f'between the first {self.keep_first} and last {events_from_tail} events'
            )

        # Construct prompt for summarization
        prompt = self._build_prompt(summary_event, forgotten_events)
        messages = [Message(role='user', content=[TextContent(text=prompt)])]

        response = self.llm.completion(
            messages=self.llm.format_messages_for_llm(messages),
            tools=[StateSummary.tool_description()],
            tool_choice={
                'type': 'function',
                'function': {'name': 'create_state_summary'},
            },
        )

        summary = self._parse_summary(response)

        self.add_metadata('response', response.model_dump())
        self.add_metadata('metrics', self.llm.metrics.get())

        forgotten_ids = [event.id for event in forgotten_events]
        return Condensation(
            action=CondensationAction(
                forgotten_events_start_id=min(forgotten_ids),
                forgotten_events_end_id=max(forgotten_ids),
                summary=str(summary),
                summary_offset=self.keep_first,
            )
        )

    def should_condense(self, view: View) -> bool:
        """Return True once the view holds more than ``max_size`` events."""
        return len(view) > self.max_size

    @classmethod
    def from_config(
        cls, config: StructuredSummaryCondenserConfig
    ) -> StructuredSummaryCondenser:
        """Build a condenser from its configuration object.

        Must be a classmethod: it is an alternate constructor taking the class
        as its first argument, not an instance.
        """
        # This condenser cannot take advantage of prompt caching. If it happens
        # to be set, we'll pay for the cache writes but never get a chance to
        # save on a read.
        llm_config = config.llm_config.model_copy()
        llm_config.caching_prompt = False

        return StructuredSummaryCondenser(
            llm=LLM(config=llm_config),
            max_size=config.max_size,
            keep_first=config.keep_first,
            max_event_length=config.max_event_length,
        )
# Module-level call, run once at import time: hands StructuredSummaryCondenserConfig
# to StructuredSummaryCondenser.register_config.
# NOTE(review): register_config is not defined in this file (presumably inherited
# from RollingCondenser); it presumably associates the config type with this class
# so that from_config can be dispatched -- confirm against the base condenser.
| StructuredSummaryCondenser.register_config(StructuredSummaryCondenserConfig) | |