Spaces:
Build error
Build error
| from pydantic import BaseModel, Field | |
| from openhands.core.logger import openhands_logger as logger | |
| from openhands.events.action import ( | |
| Action, | |
| ChangeAgentStateAction, | |
| MessageAction, | |
| NullAction, | |
| ) | |
| from openhands.events.event import EventSource | |
| from openhands.events.observation import ( | |
| AgentStateChangedObservation, | |
| NullObservation, | |
| Observation, | |
| ) | |
| from openhands.events.serialization.event import event_to_dict | |
| from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput | |
# Union of the Invariant node types that may appear in a parsed trace.
TraceElement = Message | ToolCall | ToolOutput | Function


def get_next_id(trace: list[TraceElement]) -> str:
    """Return the smallest positive integer, as a string, not yet used as a ToolCall id.

    Args:
        trace: Parsed trace; only ``ToolCall`` elements are inspected.

    Returns:
        ``'1'`` when the trace has no tool calls, otherwise the first unused
        id in the sequence ``'1'``, ``'2'``, ``'3'``, ...
    """
    # Collect ids into a set so each membership test is O(1) instead of a
    # linear list scan (this runs for every parsed tool call).
    used_ids = {el.id for el in trace if isinstance(el, ToolCall)}
    candidate = 1
    # Always terminates: used_ids is finite, so some candidate is free.
    while str(candidate) in used_ids:
        candidate += 1
    return str(candidate)
def get_last_id(
    trace: list[TraceElement],
) -> str | None:
    """Return the id of the most recent ``ToolCall`` in *trace*.

    Args:
        trace: Parsed trace, searched from the end backwards.

    Returns:
        The id of the last ``ToolCall`` element, or ``None`` if the trace
        contains no tool calls.
    """
    tool_call_ids = (el.id for el in reversed(trace) if isinstance(el, ToolCall))
    return next(tool_call_ids, None)
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
    """Convert an OpenHands action into Invariant trace elements.

    Args:
        trace: The trace built so far. It is only read, to choose a fresh
            tool-call id; it is not modified.
        action: The action to convert.

    Returns:
        A new list of elements for the caller to append:

        - ``MessageAction``: one ``Message`` with role ``'user'`` when the
          source is ``EventSource.USER``, otherwise ``'assistant'``.
        - ``NullAction`` / ``ChangeAgentStateAction``: nothing.
        - Any other action with a non-None ``action`` attribute: an optional
          assistant ``Message`` holding the ``thought`` argument, followed by
          a ``ToolCall`` whose function is named after ``action.action``. The
          function arguments are the remaining serialized ``args``.
        - Anything else is logged as an error and yields nothing.
    """
    inv_trace: list[TraceElement] = []
    if isinstance(action, MessageAction):
        role = 'user' if action.source == EventSource.USER else 'assistant'
        inv_trace.append(Message(role=role, content=action.content))
    elif isinstance(action, (NullAction, ChangeAgentStateAction)):
        pass
    elif hasattr(action, 'action') and action.action is not None:
        event_dict = event_to_dict(action)
        args = event_dict.get('args', {})
        # Remove `thought` from the tool arguments. When present, it is
        # emitted as a separate assistant Message before the ToolCall.
        thought = args.pop('thought', None)
        function = Function(name=action.action, arguments=args)
        if thought is not None:
            inv_trace.append(Message(role='assistant', content=thought))
        # get_next_id scans the whole trace, so only compute it when a
        # ToolCall is actually being created.
        inv_trace.append(
            ToolCall(id=get_next_id(trace), type='function', function=function)
        )
    else:
        logger.error(f'Unknown action type: {type(action)}')
    return inv_trace
def parse_observation(
    trace: list[TraceElement], obs: Observation
) -> list[TraceElement]:
    """Convert an OpenHands observation into Invariant trace elements.

    Args:
        trace: The trace built so far. It is only read, to find the tool call
            this observation answers.
        obs: The observation to convert.

    Returns:
        ``[]`` for ``NullObservation`` and ``AgentStateChangedObservation``.
        For any other observation whose ``content`` is not None, a single
        ``ToolOutput`` whose ``tool_call_id`` is the id of the last
        ``ToolCall`` in *trace* (``None`` if there is none). Observations
        without usable content are logged as errors and produce ``[]``.
    """
    if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
        return []
    if not hasattr(obs, 'content') or obs.content is None:
        logger.error(f'Unknown observation type: {type(obs)}')
        return []
    return [ToolOutput(role='tool', content=obs.content, tool_call_id=get_last_id(trace))]
def parse_element(
    trace: list[TraceElement], element: Action | Observation
) -> list[TraceElement]:
    """Parse a single event against *trace*, dispatching on its kind.

    ``Action`` instances go to ``parse_action``; anything else is treated as
    an observation and handed to ``parse_observation``.
    """
    parser = parse_action if isinstance(element, Action) else parse_observation
    return parser(trace, element)
def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
    """Flatten a sequence of (action, observation) pairs into one Invariant trace.

    Pairs are processed in order. Each action and then its observation is
    parsed against everything accumulated so far. That is how tool-call ids
    stay unique and each ToolOutput links back to the preceding ToolCall.
    """
    flattened: list[TraceElement] = []
    for step in trace:
        step_action, step_obs = step
        flattened.extend(parse_action(flattened, step_action))
        flattened.extend(parse_observation(flattened, step_obs))
    return flattened
class InvariantState(BaseModel):
    # Incrementally built Invariant trace. Elements are appended in the order
    # actions/observations are added. (Kept as a comment rather than a class
    # docstring so the model's JSON schema description is unchanged.)
    trace: list[TraceElement] = Field(default_factory=list)

    def add_action(self, action: Action) -> None:
        """Parse *action* against the current trace and append the result in place."""
        parsed = parse_action(self.trace, action)
        self.trace.extend(parsed)

    def add_observation(self, obs: Observation) -> None:
        """Parse *obs* against the current trace and append the result in place."""
        parsed = parse_observation(self.trace, obs)
        self.trace.extend(parsed)

    def concatenate(self, other: 'InvariantState') -> None:
        """Append every element of *other*'s trace to this trace, in place.

        Elements are appended as-is; ToolCall ids coming from *other* are not
        renumbered.
        NOTE(review): ids can therefore collide with ids already present in
        this trace — confirm that callers expect this.
        """
        incoming = other.trace
        self.trace.extend(incoming)