# Scraped page header (not part of the module source):
# Author: Speedofmastery
# Commit d94d354 - "Deploy OpenManus Complete AI Platform - 200+ Models + Mobile Auth + Cloudflare Services"
# File size: 14.2 kB
"""
KV Storage integration for OpenManus
Provides interface to Cloudflare KV operations
"""
import hashlib
import json
import time
from typing import Any, Dict, List, Optional
from urllib.parse import quote, urlencode

from app.logger import logger

from .client import CloudflareClient, CloudflareError
class KVStorage:
    """Cloudflare KV Storage client.

    Wraps two KV namespaces -- one for sessions and one for cache entries --
    and exposes generic get/set/delete/list operations plus session, cache,
    per-user, conversation and agent-execution helpers built on top of them.

    Each generic operation can be routed through a companion Worker
    (``use_worker=True``, the default; routes under ``api/kv/...``) or sent
    directly to the Cloudflare KV REST API under
    ``accounts/{account_id}/storage/kv/namespaces``.
    """

    # Cloudflare KV does not accept expiration TTLs shorter than 60 seconds.
    _MIN_KV_TTL = 60

    def __init__(
        self,
        client: CloudflareClient,
        sessions_namespace_id: str,
        cache_namespace_id: str,
    ):
        self.client = client
        self.sessions_namespace_id = sessions_namespace_id
        self.cache_namespace_id = cache_namespace_id
        self.base_endpoint = f"accounts/{client.account_id}/storage/kv/namespaces"

    def _get_namespace_id(self, namespace_type: str) -> str:
        """Return the namespace ID for ``namespace_type``.

        ``"cache"`` selects the cache namespace; every other value (including
        unrecognised ones) falls back to the sessions namespace.
        """
        if namespace_type == "cache":
            return self.cache_namespace_id
        return self.sessions_namespace_id

    @staticmethod
    def _build_query(params: Dict[str, Any]) -> str:
        """URL-encode ``params`` as a query string, omitting falsy values."""
        return urlencode({k: v for k, v in params.items() if v})

    def _value_endpoint(self, namespace_id: str, key: str) -> str:
        """Return the REST API path for a single key, percent-encoding the key.

        ``:`` is kept literal because every key this class builds uses it as a
        separator and it is valid in a path segment; characters such as ``/``,
        ``?``, ``#``, ``%`` and spaces are encoded so that user-supplied key
        fragments cannot change the request URL.
        """
        return f"{self.base_endpoint}/{namespace_id}/values/{quote(key, safe=':')}"

    @staticmethod
    def _decode_value(value: Any, parse_json: bool) -> Any:
        """Decode a stored string as JSON when requested.

        The raw value is returned unchanged when it is not a non-empty string,
        when ``parse_json`` is false, or when it is not valid JSON.
        """
        if parse_json and isinstance(value, str) and value:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value
        return value

    async def set_value(
        self,
        key: str,
        value: Any,
        namespace_type: str = "sessions",
        ttl: Optional[int] = None,
        use_worker: bool = True,
    ) -> Dict[str, Any]:
        """Store ``value`` under ``key``.

        Strings are stored verbatim; any other value is JSON-encoded.

        Args:
            key: KV key name.
            value: Value to store.
            namespace_type: ``"cache"`` or ``"sessions"`` (see ``_get_namespace_id``).
            ttl: Expiration in seconds; a falsy value means no expiration.
            use_worker: Route through the Worker instead of the REST API.

        Returns:
            A summary dict (``success``, ``key``, ``namespace``, ``ttl``) merged
            with the client response; response fields take precedence.

        Raises:
            CloudflareError: if the request fails (logged, then re-raised).
        """
        namespace_id = self._get_namespace_id(namespace_type)
        serialized_value = value if isinstance(value, str) else json.dumps(value)

        try:
            if use_worker:
                set_data = {
                    "key": key,
                    "value": serialized_value,
                    "namespace": namespace_type,
                }
                if ttl:
                    set_data["ttl"] = ttl
                response = await self.client.post(
                    "api/kv/set", data=set_data, use_worker=True
                )
            else:
                endpoint = self._value_endpoint(namespace_id, key)
                query_string = self._build_query({"expiration_ttl": ttl})
                if query_string:
                    endpoint += f"?{query_string}"
                response = await self.client.put(
                    endpoint, data={"value": serialized_value}
                )

            return {
                "success": True,
                "key": key,
                "namespace": namespace_type,
                "ttl": ttl,
                **response,
            }
        except CloudflareError as e:
            logger.error(f"KV set value failed: {e}")
            raise

    async def get_value(
        self,
        key: str,
        namespace_type: str = "sessions",
        parse_json: bool = True,
        use_worker: bool = True,
    ) -> Optional[Any]:
        """Read the value stored under ``key``.

        Args:
            key: KV key name.
            namespace_type: ``"cache"`` or ``"sessions"``.
            parse_json: Try to decode string values as JSON.
            use_worker: Route through the Worker instead of the REST API.

        Returns:
            The (optionally JSON-decoded) value. Returns ``None`` if the request
            fails with status 404, or if the Worker response has no ``value``.

        Raises:
            CloudflareError: for non-404 failures (logged, then re-raised).
        """
        namespace_id = self._get_namespace_id(namespace_type)

        try:
            if use_worker:
                # NOTE(review): the key is left unencoded on Worker routes because
                # it is unknown whether the Worker percent-decodes path params.
                response = await self.client.get(
                    f"api/kv/get/{key}?{urlencode({'namespace': namespace_type})}",
                    use_worker=True,
                )
                if response and "value" in response:
                    return self._decode_value(response["value"], parse_json)
                return None

            response = await self.client.get(self._value_endpoint(namespace_id, key))
            # The response may be an API envelope ({"result": ...}) or the raw
            # stored value; only treat it as an envelope when it is a dict, so a
            # raw string containing "result" is not mistaken for one.
            if isinstance(response, dict) and "result" in response:
                result = response["result"]
                value = result.get("value") if isinstance(result, dict) else result
            else:
                value = response
            return self._decode_value(value, parse_json)
        except CloudflareError as e:
            if getattr(e, "status_code", None) == 404:
                return None
            logger.error(f"KV get value failed: {e}")
            raise

    async def delete_value(
        self, key: str, namespace_type: str = "sessions", use_worker: bool = True
    ) -> Dict[str, Any]:
        """Delete ``key`` from the selected namespace.

        Returns:
            A summary dict (``success``, ``key``, ``namespace``) merged with the
            client response; response fields take precedence.

        Raises:
            CloudflareError: if the request fails (logged, then re-raised).
        """
        namespace_id = self._get_namespace_id(namespace_type)

        try:
            if use_worker:
                response = await self.client.delete(
                    f"api/kv/delete/{key}?{urlencode({'namespace': namespace_type})}",
                    use_worker=True,
                )
            else:
                response = await self.client.delete(
                    self._value_endpoint(namespace_id, key)
                )

            return {
                "success": True,
                "key": key,
                "namespace": namespace_type,
                **response,
            }
        except CloudflareError as e:
            logger.error(f"KV delete value failed: {e}")
            raise

    async def list_keys(
        self,
        namespace_type: str = "sessions",
        prefix: str = "",
        limit: int = 1000,
        use_worker: bool = True,
    ) -> Dict[str, Any]:
        """List keys in a namespace, optionally filtered by ``prefix``.

        Empty ``prefix`` and zero ``limit`` are omitted from the query.

        Returns:
            A dict with ``namespace``, ``prefix`` and ``keys``, merged with the
            client response. ``keys`` comes from ``result`` (REST API envelope)
            or ``keys`` (Worker).

        Raises:
            CloudflareError: if the request fails (logged, then re-raised).
        """
        namespace_id = self._get_namespace_id(namespace_type)

        try:
            if use_worker:
                query_string = self._build_query(
                    {"namespace": namespace_type, "prefix": prefix, "limit": limit}
                )
                response = await self.client.get(
                    f"api/kv/list?{query_string}", use_worker=True
                )
            else:
                query_string = self._build_query({"prefix": prefix, "limit": limit})
                response = await self.client.get(
                    f"{self.base_endpoint}/{namespace_id}/keys?{query_string}"
                )

            return {
                "namespace": namespace_type,
                "prefix": prefix,
                "keys": (
                    response.get("result", [])
                    if "result" in response
                    else response.get("keys", [])
                ),
                **response,
            }
        except CloudflareError as e:
            logger.error(f"KV list keys failed: {e}")
            raise

    # Session-specific methods

    async def set_session(
        self,
        session_id: str,
        session_data: Dict[str, Any],
        ttl: int = 86400,  # 24 hours default
    ) -> Dict[str, Any]:
        """Store session data under ``session:{session_id}``.

        Adds ``created_at`` (kept if already present) and ``expires_at``
        (now + ``ttl``) as Unix timestamps in seconds.
        """
        now = int(time.time())
        data = {
            **session_data,
            "created_at": session_data.get("created_at", now),
            "expires_at": now + ttl,
        }
        return await self.set_value(f"session:{session_id}", data, "sessions", ttl)

    async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return session data, or ``None`` if it is missing or expired.

        Sessions whose ``expires_at`` has passed are deleted eagerly.
        """
        session = await self.get_value(f"session:{session_id}", "sessions")

        if session and isinstance(session, dict):
            expires_at = session.get("expires_at")
            if expires_at and int(time.time()) > expires_at:
                await self.delete_session(session_id)
                return None

        return session

    async def delete_session(self, session_id: str) -> Dict[str, Any]:
        """Delete the session stored under ``session:{session_id}``."""
        return await self.delete_value(f"session:{session_id}", "sessions")

    async def update_session(
        self, session_id: str, updates: Dict[str, Any], extend_ttl: Optional[int] = None
    ) -> Dict[str, Any]:
        """Merge ``updates`` into an existing session and write it back.

        Args:
            session_id: Session to update.
            updates: Fields to overwrite/add.
            extend_ttl: New TTL in seconds. If omitted, the session keeps its
                remaining lifetime, raised to KV's 60-second minimum TTL.
                Sessions with no ``expires_at`` get 24 hours.

        Raises:
            CloudflareError: if the session does not exist (or on request failure).
        """
        existing_session = await self.get_session(session_id)
        if not existing_session:
            raise CloudflareError("Session not found")

        now = int(time.time())
        updated_data = {**existing_session, **updates, "updated_at": now}

        if extend_ttl:
            ttl = extend_ttl
        elif existing_session.get("expires_at"):
            # Preserve the remaining lifetime instead of resetting it. KV rejects
            # TTLs below 60 s, and a remaining TTL of 0 would otherwise have
            # fallen through to a full 24 h reset.
            ttl = max(self._MIN_KV_TTL, existing_session["expires_at"] - now)
        else:
            ttl = 86400

        return await self.set_session(session_id, updated_data, ttl)

    # Cache-specific methods

    async def set_cache(
        self, key: str, data: Any, ttl: int = 3600  # 1 hour default
    ) -> Dict[str, Any]:
        """Cache ``data`` under ``cache:{key}`` with timestamps for expiry checks."""
        now = int(time.time())
        cache_data = {
            "data": data,
            "cached_at": now,
            "expires_at": now + ttl,
        }
        return await self.set_value(f"cache:{key}", cache_data, "cache", ttl)

    async def get_cache(self, key: str) -> Optional[Any]:
        """Return the cached payload for ``key``, or ``None`` if missing/expired.

        Expired entries are deleted eagerly. Non-dict values are returned as-is.
        """
        cached = await self.get_value(f"cache:{key}", "cache")

        if cached and isinstance(cached, dict):
            expires_at = cached.get("expires_at")
            if expires_at and int(time.time()) > expires_at:
                await self.delete_cache(key)
                return None
            return cached.get("data")

        return cached

    async def delete_cache(self, key: str) -> Dict[str, Any]:
        """Delete the cache entry ``cache:{key}``."""
        return await self.delete_value(f"cache:{key}", "cache")

    # User-specific methods

    async def set_user_cache(
        self, user_id: str, key: str, data: Any, ttl: int = 3600
    ) -> Dict[str, Any]:
        """Cache ``data`` for a user under ``user:{user_id}:{key}``."""
        user_key = f"user:{user_id}:{key}"
        return await self.set_cache(user_key, data, ttl)

    async def get_user_cache(self, user_id: str, key: str) -> Optional[Any]:
        """Return the user's cached payload for ``key`` (see ``get_cache``)."""
        user_key = f"user:{user_id}:{key}"
        return await self.get_cache(user_key)

    async def delete_user_cache(self, user_id: str, key: str) -> Dict[str, Any]:
        """Delete the user's cache entry for ``key``."""
        user_key = f"user:{user_id}:{key}"
        return await self.delete_cache(user_key)

    async def get_user_cache_keys(self, user_id: str, limit: int = 100) -> List[str]:
        """Return the user-level keys (prefix stripped) cached for ``user_id``."""
        prefix = f"cache:user:{user_id}:"
        result = await self.list_keys("cache", prefix, limit)

        keys = []
        for key_info in result.get("keys", []):
            name = (
                key_info.get("name", "") if isinstance(key_info, dict) else str(key_info)
            )
            # Strip only the leading prefix; str.replace would also remove any
            # later occurrence of the prefix inside the user's own key.
            if name.startswith(prefix):
                keys.append(name[len(prefix):])

        return keys

    # Conversation caching

    async def cache_conversation(
        self,
        conversation_id: str,
        messages: List[Dict[str, Any]],
        ttl: int = 7200,  # 2 hours default
    ) -> Dict[str, Any]:
        """Cache a conversation's messages together with a ``last_updated`` stamp."""
        return await self.set_cache(
            f"conversation:{conversation_id}",
            {"messages": messages, "last_updated": int(time.time())},
            ttl,
        )

    async def get_cached_conversation(
        self, conversation_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the cached conversation payload, or ``None``."""
        return await self.get_cache(f"conversation:{conversation_id}")

    # Agent execution caching

    async def cache_agent_execution(
        self, execution_id: str, execution_data: Dict[str, Any], ttl: int = 3600
    ) -> Dict[str, Any]:
        """Cache agent execution data under ``execution:{execution_id}``."""
        return await self.set_cache(f"execution:{execution_id}", execution_data, ttl)

    async def get_cached_agent_execution(
        self, execution_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return cached agent execution data, or ``None``."""
        return await self.get_cache(f"execution:{execution_id}")

    # Batch operations

    async def set_batch(
        self,
        items: List[Dict[str, Any]],
        namespace_type: str = "cache",
        ttl: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Set several values sequentially (no native batch call is used).

        Each item needs ``key`` and ``value`` and may carry its own ``ttl``
        (defaulting to ``ttl``). Failures are recorded per item rather than
        raised, so one bad item does not abort the batch.
        """
        results = []
        successful = 0
        failed = 0

        for item in items:
            try:
                key = item["key"]
                value = item["value"]
                item_ttl = item.get("ttl", ttl)
                result = await self.set_value(key, value, namespace_type, item_ttl)
                results.append({"key": key, "success": True, "result": result})
                successful += 1
            except Exception as e:
                results.append(
                    {"key": item.get("key"), "success": False, "error": str(e)}
                )
                failed += 1

        return {
            "success": failed == 0,
            "successful": successful,
            "failed": failed,
            "total": len(items),
            "results": results,
        }

    async def get_batch(
        self, keys: List[str], namespace_type: str = "cache"
    ) -> Dict[str, Any]:
        """Get several values sequentially; failed lookups map to ``None``."""
        results = {}

        for key in keys:
            try:
                results[key] = await self.get_value(key, namespace_type)
            except Exception as e:
                logger.error(f"Failed to get key {key}: {e}")
                results[key] = None

        return results

    def _hash_params(self, params: Dict[str, Any]) -> str:
        """Derive a short, order-independent cache-key fragment from ``params``.

        Returns ``"no-params"`` for empty input; otherwise the first 16 hex
        digits of the MD5 of the key-sorted JSON (used for cache keys only,
        not for security).
        """
        if not params:
            return "no-params"

        params_str = json.dumps(params, sort_keys=True)
        return hashlib.md5(params_str.encode()).hexdigest()[:16]
# `time` is also imported at the top of the module; this trailing import is redundant but harmless.
import time