import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Callable

from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event, EventSource
from openhands.events.event_store import EventStore
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.io import json
from openhands.storage import FileStore
from openhands.storage.locations import (
    get_conversation_dir,
)
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue


class EventStreamSubscriber(str, Enum):
    AGENT_CONTROLLER = 'agent_controller'
    SECURITY_ANALYZER = 'security_analyzer'
    RESOLVER = 'openhands_resolver'
    SERVER = 'server'
    RUNTIME = 'runtime'
    MEMORY = 'memory'
    MAIN = 'main'
    TEST = 'test'


async def session_exists(
    sid: str, file_store: FileStore, user_id: str | None = None
) -> bool:
    try:
        await call_sync_from_async(file_store.list, get_conversation_dir(sid, user_id))
        return True
    except FileNotFoundError:
        return False
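

# A minimal usage sketch (hypothetical sid and root path; LocalFileStore from
# openhands.storage.local is assumed as the FileStore implementation):
#
#     store = LocalFileStore('/tmp/openhands-events')
#     if not await session_exists('session-1', store):
#         ...  # no events stored yet for this conversation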


class EventStream(EventStore):
    secrets: dict[str, str]
    # For each subscriber ID, there is a map of callback functions - useful
    # when there are multiple listeners
    _subscribers: dict[str, dict[str, Callable]]
    _lock: threading.Lock
    _queue: queue.Queue[Event]
    _queue_thread: threading.Thread
    _queue_loop: asyncio.AbstractEventLoop | None
    _thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
    _thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
    _write_page_cache: list[dict]

    def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
        super().__init__(sid, file_store, user_id)
        self._stop_flag = threading.Event()
        self._queue: queue.Queue[Event] = queue.Queue()
        self._thread_pools = {}
        self._thread_loops = {}
        self._queue_loop = None
        self._queue_thread = threading.Thread(target=self._run_queue_loop)
        self._queue_thread.daemon = True
        self._queue_thread.start()
        self._subscribers = {}
        self._lock = threading.Lock()
        self.secrets = {}
        self._write_page_cache = []
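
    # Each subscriber callback runs on its own single-worker ThreadPoolExecutor;
    # this initializer runs once in that worker thread and gives it a dedicated
    # asyncio event loop, so callbacks that schedule coroutines have a loop to use.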
    def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        if subscriber_id not in self._thread_loops:
            self._thread_loops[subscriber_id] = {}
        self._thread_loops[subscriber_id][callback_id] = loop

    def close(self) -> None:
        self._stop_flag.set()
        if self._queue_thread.is_alive():
            self._queue_thread.join()

        subscriber_ids = list(self._subscribers.keys())
        for subscriber_id in subscriber_ids:
            callback_ids = list(self._subscribers[subscriber_id].keys())
            for callback_id in callback_ids:
                self._clean_up_subscriber(subscriber_id, callback_id)

        # Clear queue
        while not self._queue.empty():
            self._queue.get()

    def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
        if subscriber_id not in self._subscribers:
            logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
            return
        if callback_id not in self._subscribers[subscriber_id]:
            logger.warning(f'Callback not found during cleanup: {callback_id}')
            return
        if (
            subscriber_id in self._thread_loops
            and callback_id in self._thread_loops[subscriber_id]
        ):
            loop = self._thread_loops[subscriber_id][callback_id]
            current_task = asyncio.current_task(loop)
            pending = [
                task for task in asyncio.all_tasks(loop) if task is not current_task
            ]
            for task in pending:
                task.cancel()
            try:
                loop.stop()
                loop.close()
            except Exception as e:
                logger.warning(
                    f'Error closing loop for {subscriber_id}/{callback_id}: {e}'
                )
            del self._thread_loops[subscriber_id][callback_id]

        if (
            subscriber_id in self._thread_pools
            and callback_id in self._thread_pools[subscriber_id]
        ):
            pool = self._thread_pools[subscriber_id][callback_id]
            pool.shutdown()
            del self._thread_pools[subscriber_id][callback_id]

        del self._subscribers[subscriber_id][callback_id]

    def subscribe(
        self,
        subscriber_id: EventStreamSubscriber,
        callback: Callable[[Event], None],
        callback_id: str,
    ) -> None:
        initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
        pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
        if subscriber_id not in self._subscribers:
            self._subscribers[subscriber_id] = {}
            self._thread_pools[subscriber_id] = {}
        if callback_id in self._subscribers[subscriber_id]:
            raise ValueError(
                f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
            )
        self._subscribers[subscriber_id][callback_id] = callback
        self._thread_pools[subscriber_id][callback_id] = pool
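
    # A minimal subscription sketch (hypothetical callback and callback_id;
    # each callback receives events on its own dedicated worker thread):
    #
    #     def on_event(event: Event) -> None:
    #         logger.info(f'received event {event.id}')
    #
    #     stream.subscribe(EventStreamSubscriber.TEST, on_event, callback_id='demo')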

    def unsubscribe(
        self, subscriber_id: EventStreamSubscriber, callback_id: str
    ) -> None:
        if subscriber_id not in self._subscribers:
            logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
            return
        if callback_id not in self._subscribers[subscriber_id]:
            logger.warning(f'Callback not found during unsubscribe: {callback_id}')
            return
        self._clean_up_subscriber(subscriber_id, callback_id)

    def add_event(self, event: Event, source: EventSource) -> None:
        if event.id != Event.INVALID_ID:
            raise ValueError(
                f'Event already has an ID: {event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
            )
        event._timestamp = datetime.now().isoformat()
        event._source = source  # type: ignore [attr-defined]
        with self._lock:
            event._id = self.cur_id  # type: ignore [attr-defined]
            self.cur_id += 1

            # Take a copy of the current write page
            current_write_page = self._write_page_cache

            data = event_to_dict(event)
            data = self._replace_secrets(data)
            event = event_from_dict(data)
            current_write_page.append(data)

            # If the page is full, create a new page for future events / other threads to use
            if len(current_write_page) == self.cache_size:
                self._write_page_cache = []

        if event.id is not None:
            # Write the event to the store - this can take some time
            self.file_store.write(
                self._get_filename_for_id(event.id, self.user_id), json.dumps(data)
            )

            # Store the cache page last - if it is not present during reads then it will simply be bypassed.
            self._store_cache_page(current_write_page)

        self._queue.put(event)
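
    # A minimal add_event sketch (MessageAction from openhands.events.action is
    # assumed as a concrete Event type; id, timestamp, and source are assigned
    # here, so callers must pass a freshly constructed event):
    #
    #     stream.add_event(MessageAction(content='hello'), EventSource.USER)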

    def _store_cache_page(self, current_write_page: list[dict]):
        """Store a page in the cache. Reading individual events is slow when there are a lot of them, so we use pages."""
        if len(current_write_page) < self.cache_size:
            return
        start = current_write_page[0]['id']
        end = start + self.cache_size
        contents = json.dumps(current_write_page)
        cache_filename = self._get_filename_for_cache(start, end)
        self.file_store.write(cache_filename, contents)
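
    # For example, with an illustrative cache_size of 25, the page holding
    # events 0..24 is written as one file covering ids [0, 25), so a reader can
    # fetch those 25 events with a single read instead of 25 separate ones.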

    def set_secrets(self, secrets: dict[str, str]) -> None:
        self.secrets = secrets.copy()

    def update_secrets(self, secrets: dict[str, str]) -> None:
        self.secrets.update(secrets)

    def _replace_secrets(self, data: dict[str, Any]) -> dict[str, Any]:
        for key in data:
            if isinstance(data[key], dict):
                data[key] = self._replace_secrets(data[key])
            elif isinstance(data[key], str):
                for secret in self.secrets.values():
                    data[key] = data[key].replace(secret, '<secret_hidden>')
        return data
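
    # Masking sketch: with secrets {'api_key': 'abc123'}, a serialized event
    # such as {'args': {'command': 'curl -H abc123'}} becomes
    # {'args': {'command': 'curl -H <secret_hidden>'}}; the replacement recurses
    # into nested dicts and is applied to every string value.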

    def _run_queue_loop(self) -> None:
        self._queue_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._queue_loop)
        try:
            self._queue_loop.run_until_complete(self._process_queue())
        finally:
            self._queue_loop.close()

    async def _process_queue(self) -> None:
        while should_continue() and not self._stop_flag.is_set():
            event = None
            try:
                event = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            # pass each event to each callback in order
            for key in sorted(self._subscribers.keys()):
                callbacks = self._subscribers[key]
                # Create a copy of the keys to avoid "dictionary changed size during iteration" error
                callback_ids = list(callbacks.keys())
                for callback_id in callback_ids:
                    # Check if callback_id still exists (might have been removed during iteration)
                    if callback_id in callbacks:
                        callback = callbacks[callback_id]
                        pool = self._thread_pools[key][callback_id]
                        future = pool.submit(callback, event)
                        future.add_done_callback(
                            self._make_error_handler(callback_id, key)
                        )

    def _make_error_handler(
        self, callback_id: str, subscriber_id: str
    ) -> Callable[[Any], None]:
        def _handle_callback_error(fut: Any) -> None:
            try:
                # This will raise any exception that occurred during callback execution
                fut.result()
            except Exception as e:
                logger.error(
                    f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
                )
                # Re-raise in the worker thread that ran the callback so the
                # error is not silently swallowed
                raise e

        return _handle_callback_error
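

# End-to-end lifecycle sketch (hypothetical sid and path; LocalFileStore and
# MessageAction are assumed from openhands.storage.local and
# openhands.events.action respectively):
#
#     store = LocalFileStore('/tmp/openhands-events')
#     stream = EventStream(sid='session-1', file_store=store)
#     stream.subscribe(EventStreamSubscriber.TEST, lambda e: print(e), 'demo')
#     stream.add_event(MessageAction(content='hello'), EventSource.USER)
#     stream.close()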