Spaces:
Paused
Paused
| """ | |
| KV Storage integration for OpenManus | |
| Provides interface to Cloudflare KV operations | |
| """ | |
import json
import time
from typing import Any, Dict, List, Optional
from urllib.parse import quote, urlencode

from app.logger import logger

from .client import CloudflareClient, CloudflareError
class KVStorage:
    """Cloudflare KV Storage client.

    Wraps a ``CloudflareClient`` and works with two KV namespaces: one for
    sessions and one for cache entries. Each low-level operation can either go
    through a Worker proxy (``use_worker=True``, the default, hitting the
    ``api/kv/...`` routes) or call the Cloudflare KV REST API directly under
    ``accounts/{account_id}/storage/kv/namespaces``.
    """

    def __init__(
        self,
        client: CloudflareClient,
        sessions_namespace_id: str,
        cache_namespace_id: str,
    ):
        self.client = client
        self.sessions_namespace_id = sessions_namespace_id
        self.cache_namespace_id = cache_namespace_id
        self.base_endpoint = f"accounts/{client.account_id}/storage/kv/namespaces"

    def _get_namespace_id(self, namespace_type: str) -> str:
        """Return the namespace ID for ``namespace_type``.

        ``"cache"`` maps to the cache namespace; every other value (including
        ``"sessions"``) falls back to the sessions namespace.
        """
        if namespace_type == "cache":
            return self.cache_namespace_id
        return self.sessions_namespace_id

    def _value_endpoint(self, namespace_id: str, key: str) -> str:
        """Build the direct KV API path for a single key.

        The key is fully percent-encoded because the KV API takes the key name
        as a URL path segment; an unencoded ``?``, ``#`` or ``/`` would
        otherwise truncate the key or hit a different route.
        """
        return f"{self.base_endpoint}/{namespace_id}/values/{quote(key, safe='')}"

    @staticmethod
    def _response_fields(response: Any) -> Dict[str, Any]:
        """Return ``response`` if it is a dict, otherwise an empty dict.

        Lets callers merge an API response into a result dict with ``**``
        without a ``TypeError`` when the body was empty or not JSON.
        """
        return response if isinstance(response, dict) else {}

    @staticmethod
    def _maybe_parse_json(value: Any, parse_json: bool) -> Any:
        """JSON-decode ``value`` when requested and it is a string.

        Strings that are not valid JSON are returned unchanged.
        """
        if parse_json and isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value
        return value

    async def set_value(
        self,
        key: str,
        value: Any,
        namespace_type: str = "sessions",
        ttl: Optional[int] = None,
        use_worker: bool = True,
    ) -> Dict[str, Any]:
        """Store ``value`` under ``key``.

        Strings are stored as-is; every other value is JSON-encoded. Because
        ``get_value`` JSON-decodes strings by default, a stored string that is
        itself valid JSON (e.g. ``"123"``) is read back decoded.

        Args:
            key: Key name.
            value: Value to store.
            namespace_type: ``"cache"`` or ``"sessions"``.
            ttl: Expiry in seconds; a falsy value means no expiry is sent.
            use_worker: Route through the Worker proxy instead of the KV API.

        Returns:
            ``success``/``key``/``namespace``/``ttl`` merged with the fields of
            the API response (response fields win on name clashes).

        Raises:
            CloudflareError: If the request fails (logged before re-raising).
        """
        namespace_id = self._get_namespace_id(namespace_type)
        serialized_value = value if isinstance(value, str) else json.dumps(value)

        try:
            if use_worker:
                set_data = {
                    "key": key,
                    "value": serialized_value,
                    "namespace": namespace_type,
                }
                if ttl:
                    set_data["ttl"] = ttl
                response = await self.client.post(
                    "api/kv/set", data=set_data, use_worker=True
                )
            else:
                endpoint = self._value_endpoint(namespace_id, key)
                if ttl:
                    endpoint += "?" + urlencode({"expiration_ttl": ttl})
                response = await self.client.put(
                    endpoint, data={"value": serialized_value}
                )
        except CloudflareError as e:
            logger.error(f"KV set value failed: {e}")
            raise

        return {
            "success": True,
            "key": key,
            "namespace": namespace_type,
            "ttl": ttl,
            **self._response_fields(response),
        }

    async def get_value(
        self,
        key: str,
        namespace_type: str = "sessions",
        parse_json: bool = True,
        use_worker: bool = True,
    ) -> Optional[Any]:
        """Fetch the value stored under ``key``.

        Args:
            key: Key name.
            namespace_type: ``"cache"`` or ``"sessions"``.
            parse_json: Try to JSON-decode string values.
            use_worker: Route through the Worker proxy instead of the KV API.

        Returns:
            The (optionally decoded) value, or ``None`` when the key is missing
            (HTTP 404) or the Worker response carries no ``value`` field.

        Raises:
            CloudflareError: For non-404 failures (logged before re-raising).
        """
        namespace_id = self._get_namespace_id(namespace_type)
        try:
            if use_worker:
                # NOTE(review): the key is not percent-encoded on Worker routes
                # because it is unknown whether the Worker decodes path params.
                response = await self.client.get(
                    f"api/kv/get/{key}?{urlencode({'namespace': namespace_type})}",
                    use_worker=True,
                )
                # Check the type first: `in` on a str is a substring test.
                if not (isinstance(response, dict) and "value" in response):
                    return None
                value = response["value"]
            else:
                response = await self.client.get(
                    self._value_endpoint(namespace_id, key)
                )
                # The KV API may hand back the raw value text; only unwrap a
                # dict envelope, never a string that merely contains "result".
                if isinstance(response, dict) and "result" in response:
                    result = response.get("result")
                    value = result.get("value") if isinstance(result, dict) else None
                else:
                    value = response
        except CloudflareError as e:
            if getattr(e, "status_code", None) == 404:
                return None
            logger.error(f"KV get value failed: {e}")
            raise

        return self._maybe_parse_json(value, parse_json)

    async def delete_value(
        self, key: str, namespace_type: str = "sessions", use_worker: bool = True
    ) -> Dict[str, Any]:
        """Delete ``key`` from the selected namespace.

        Returns:
            ``success``/``key``/``namespace`` merged with the API response fields.

        Raises:
            CloudflareError: If the request fails (logged before re-raising).
        """
        namespace_id = self._get_namespace_id(namespace_type)
        try:
            if use_worker:
                response = await self.client.delete(
                    f"api/kv/delete/{key}?{urlencode({'namespace': namespace_type})}",
                    use_worker=True,
                )
            else:
                response = await self.client.delete(
                    self._value_endpoint(namespace_id, key)
                )
        except CloudflareError as e:
            logger.error(f"KV delete value failed: {e}")
            raise

        return {
            "success": True,
            "key": key,
            "namespace": namespace_type,
            **self._response_fields(response),
        }

    async def list_keys(
        self,
        namespace_type: str = "sessions",
        prefix: str = "",
        limit: int = 1000,
        use_worker: bool = True,
        cursor: Optional[str] = None,
    ) -> Dict[str, Any]:
        """List keys in a namespace, optionally filtered by ``prefix``.

        Args:
            namespace_type: ``"cache"`` or ``"sessions"``.
            prefix: Only list keys starting with this string (omitted if empty).
            limit: Maximum number of keys to request (omitted if 0).
            use_worker: Route through the Worker proxy instead of the KV API.
            cursor: Optional pagination cursor, sent only when provided.

        Returns:
            ``namespace``, ``prefix`` and ``keys`` (taken from ``result`` if the
            response has it, else from ``keys``) merged with the response fields.

        Raises:
            CloudflareError: If the request fails (logged before re-raising).
        """
        namespace_id = self._get_namespace_id(namespace_type)

        params: Dict[str, Any] = {"prefix": prefix, "limit": limit, "cursor": cursor}
        if use_worker:
            params = {"namespace": namespace_type, **params}
        # Falsy parameters (empty prefix, no cursor, limit 0) are not sent;
        # urlencode escapes characters such as '&' or '=' inside the prefix.
        query_string = urlencode({k: v for k, v in params.items() if v})

        if use_worker:
            endpoint = "api/kv/list"
        else:
            endpoint = f"{self.base_endpoint}/{namespace_id}/keys"
        if query_string:
            endpoint += f"?{query_string}"

        try:
            if use_worker:
                response = await self.client.get(endpoint, use_worker=True)
            else:
                response = await self.client.get(endpoint)
        except CloudflareError as e:
            logger.error(f"KV list keys failed: {e}")
            raise

        fields = self._response_fields(response)
        keys = fields.get("result", []) if "result" in fields else fields.get("keys", [])
        return {
            "namespace": namespace_type,
            "prefix": prefix,
            "keys": keys,
            **fields,
        }

    # Session-specific methods
    async def set_session(
        self,
        session_id: str,
        session_data: Dict[str, Any],
        ttl: int = 86400,  # 24 hours default
    ) -> Dict[str, Any]:
        """Store session data under ``session:{session_id}``.

        Adds ``created_at`` (kept if already present) and ``expires_at``
        (now + ``ttl``) as Unix timestamps in seconds.
        """
        now = int(time.time())
        data = {
            **session_data,
            "created_at": session_data.get("created_at", now),
            "expires_at": now + ttl,
        }
        return await self.set_value(f"session:{session_id}", data, "sessions", ttl)

    async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return session data, deleting it and returning ``None`` if its
        ``expires_at`` timestamp has passed."""
        session = await self.get_value(f"session:{session_id}", "sessions")
        if session and isinstance(session, dict):
            expires_at = session.get("expires_at")
            if expires_at and int(time.time()) > expires_at:
                await self.delete_session(session_id)
                return None
        return session

    async def delete_session(self, session_id: str) -> Dict[str, Any]:
        """Delete the session stored under ``session:{session_id}``."""
        return await self.delete_value(f"session:{session_id}", "sessions")

    async def update_session(
        self, session_id: str, updates: Dict[str, Any], extend_ttl: Optional[int] = None
    ) -> Dict[str, Any]:
        """Merge ``updates`` into an existing session and store it again.

        The TTL is ``extend_ttl`` when given, otherwise the time remaining until
        the stored ``expires_at``; if neither yields a truthy value the 24h
        default (86400 s) is used.

        Raises:
            CloudflareError: If the session does not exist (or has expired).
        """
        existing_session = await self.get_session(session_id)
        if not existing_session:
            raise CloudflareError("Session not found")

        now = int(time.time())
        updated_data = {**existing_session, **updates, "updated_at": now}

        ttl = None
        if extend_ttl:
            ttl = extend_ttl
        elif existing_session.get("expires_at"):
            # NOTE(review): a remaining TTL of 0 falls through to the 24h default.
            ttl = max(0, existing_session["expires_at"] - now)

        return await self.set_session(session_id, updated_data, ttl or 86400)

    # Cache-specific methods
    async def set_cache(
        self, key: str, data: Any, ttl: int = 3600  # 1 hour default
    ) -> Dict[str, Any]:
        """Store ``data`` under ``cache:{key}`` wrapped with ``cached_at`` and
        ``expires_at`` Unix timestamps."""
        now = int(time.time())
        cache_data = {
            "data": data,
            "cached_at": now,
            "expires_at": now + ttl,
        }
        return await self.set_value(f"cache:{key}", cache_data, "cache", ttl)

    async def get_cache(self, key: str) -> Optional[Any]:
        """Return the cached ``data`` for ``key``.

        Expired wrapped entries are deleted and ``None`` is returned. Values
        that are not dicts are returned unchanged.
        """
        cached = await self.get_value(f"cache:{key}", "cache")
        if cached and isinstance(cached, dict):
            expires_at = cached.get("expires_at")
            if expires_at and int(time.time()) > expires_at:
                await self.delete_cache(key)
                return None
            return cached.get("data")
        return cached

    async def delete_cache(self, key: str) -> Dict[str, Any]:
        """Delete the cache entry stored under ``cache:{key}``."""
        return await self.delete_value(f"cache:{key}", "cache")

    # User-specific methods
    async def set_user_cache(
        self, user_id: str, key: str, data: Any, ttl: int = 3600
    ) -> Dict[str, Any]:
        """Cache ``data`` under the user-scoped key ``user:{user_id}:{key}``."""
        user_key = f"user:{user_id}:{key}"
        return await self.set_cache(user_key, data, ttl)

    async def get_user_cache(self, user_id: str, key: str) -> Optional[Any]:
        """Return cached data for the user-scoped key ``user:{user_id}:{key}``."""
        user_key = f"user:{user_id}:{key}"
        return await self.get_cache(user_key)

    async def delete_user_cache(self, user_id: str, key: str) -> Dict[str, Any]:
        """Delete the user-scoped cache entry ``user:{user_id}:{key}``."""
        user_key = f"user:{user_id}:{key}"
        return await self.delete_cache(user_key)

    async def get_user_cache_keys(self, user_id: str, limit: int = 100) -> List[str]:
        """Return the user's cache keys with the ``cache:user:{user_id}:``
        prefix stripped."""
        prefix = f"cache:user:{user_id}:"
        result = await self.list_keys("cache", prefix, limit)

        keys = []
        for key_info in result.get("keys") or []:
            if isinstance(key_info, dict):
                name = key_info.get("name", "")
            else:
                name = str(key_info)
            if name.startswith(prefix):
                # Slice instead of str.replace so only the leading prefix is
                # removed, not later repeats inside the key itself.
                keys.append(name[len(prefix):])
        return keys

    # Conversation caching
    async def cache_conversation(
        self,
        conversation_id: str,
        messages: List[Dict[str, Any]],
        ttl: int = 7200,  # 2 hours default
    ) -> Dict[str, Any]:
        """Cache a conversation's messages with a ``last_updated`` timestamp."""
        return await self.set_cache(
            f"conversation:{conversation_id}",
            {"messages": messages, "last_updated": int(time.time())},
            ttl,
        )

    async def get_cached_conversation(
        self, conversation_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the cached conversation payload, or ``None`` if absent/expired."""
        return await self.get_cache(f"conversation:{conversation_id}")

    # Agent execution caching
    async def cache_agent_execution(
        self, execution_id: str, execution_data: Dict[str, Any], ttl: int = 3600
    ) -> Dict[str, Any]:
        """Cache agent execution data under ``execution:{execution_id}``."""
        return await self.set_cache(f"execution:{execution_id}", execution_data, ttl)

    async def get_cached_agent_execution(
        self, execution_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return cached agent execution data, or ``None`` if absent/expired."""
        return await self.get_cache(f"execution:{execution_id}")

    # Batch operations
    async def set_batch(
        self,
        items: List[Dict[str, Any]],
        namespace_type: str = "cache",
        ttl: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Set several values one after another (best effort).

        Each item needs ``key`` and ``value`` and may carry its own ``ttl``.
        Per-item failures are recorded in ``results`` rather than raised.
        """
        results = []
        successful = 0
        failed = 0

        for item in items:
            try:
                key = item["key"]
                value = item["value"]
                item_ttl = item.get("ttl", ttl)
                result = await self.set_value(key, value, namespace_type, item_ttl)
                results.append({"key": key, "success": True, "result": result})
                successful += 1
            except Exception as e:
                results.append(
                    {"key": item.get("key"), "success": False, "error": str(e)}
                )
                failed += 1

        return {
            "success": failed == 0,
            "successful": successful,
            "failed": failed,
            "total": len(items),
            "results": results,
        }

    async def get_batch(
        self, keys: List[str], namespace_type: str = "cache"
    ) -> Dict[str, Any]:
        """Fetch several keys one after another; failures are logged and map to ``None``."""
        results = {}
        for key in keys:
            try:
                results[key] = await self.get_value(key, namespace_type)
            except Exception as e:
                logger.error(f"Failed to get key {key}: {e}")
                results[key] = None
        return results

    def _hash_params(self, params: Dict[str, Any]) -> str:
        """Return a 16-hex-char MD5 digest of ``params`` for use in cache keys.

        Keys are sorted before hashing so dict ordering does not matter; empty
        params map to ``"no-params"``. Not intended for security purposes.
        """
        if not params:
            return "no-params"

        import hashlib

        params_str = json.dumps(params, sort_keys=True)
        return hashlib.md5(params_str.encode()).hexdigest()[:16]
| # Add time import at the top | |
| import time | |