Spaces:
Sleeping
Sleeping
| import os | |
| import hashlib | |
| import secrets | |
| from typing import Dict, Optional | |
| from datetime import datetime, timedelta | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class APIKeyManager:
    """Manage API-key authentication and per-minute rate limiting.

    Preconfigured keys are read from the ``API_KEY_1`` and ``API_KEY_2``
    environment variables; further keys can be generated and revoked at
    runtime.  All state is held in dictionaries on the instance; nothing is
    persisted by this class.

    Attributes:
        api_keys: ``{api_key: record}`` where each record has the keys
            ``"user"``, ``"created"``, ``"last_used"`` and ``"request_count"``.
        rate_limits: ``{api_key: {minute_bucket: count}}``.
        max_requests_per_minute: Requests allowed per key per minute bucket.
    """

    # strftime/strptime pattern used to name the per-minute buckets.
    _MINUTE_FORMAT = "%Y-%m-%d-%H-%M"
    # Buckets older than this many minutes are pruned on every rate check.
    _RETENTION_MINUTES = 5
    # Applied when RATE_LIMIT is unset or cannot be parsed as an integer.
    _DEFAULT_RATE_LIMIT = 10
    # (environment variable, user name) pairs for the preconfigured keys.
    _ENV_KEYS = (("API_KEY_1", "user1"), ("API_KEY_2", "user2"))

    def __init__(self):
        """Load the preconfigured keys and the rate limit from the environment.

        A key is registered only when its environment variable is set to a
        non-empty value.  There is deliberately no placeholder fallback: a
        well-known default string would otherwise be accepted as a valid
        credential whenever the variable is missing.
        """
        log = logging.getLogger(__name__)

        self.api_keys = {}
        for env_name, user in self._ENV_KEYS:
            configured_key = os.getenv(env_name)
            if not configured_key:
                log.warning(
                    "%s is not set; no API key registered for %s",
                    env_name,
                    user,
                )
                continue
            self.api_keys[configured_key] = self._new_record(user)

        self.rate_limits = {}  # {api_key: {minute: count}}

        raw_limit = os.getenv("RATE_LIMIT")
        if raw_limit is None:
            self.max_requests_per_minute = self._DEFAULT_RATE_LIMIT
        else:
            try:
                self.max_requests_per_minute = int(raw_limit)
            except ValueError:
                # A malformed setting must not abort construction (and with
                # it the import of the module that builds the global instance).
                log.warning(
                    "RATE_LIMIT=%r is not an integer; using %d",
                    raw_limit,
                    self._DEFAULT_RATE_LIMIT,
                )
                self.max_requests_per_minute = self._DEFAULT_RATE_LIMIT

    @staticmethod
    def _new_record(user):
        """Build the bookkeeping record stored for a freshly registered key."""
        return {
            "user": user,
            "created": datetime.now(),
            "last_used": None,
            "request_count": 0,
        }

    @classmethod
    def _is_stale(cls, minute_key, cutoff):
        """Tell whether a bucket name is older than *cutoff* or unparseable."""
        try:
            return datetime.strptime(minute_key, cls._MINUTE_FORMAT) < cutoff
        except ValueError:
            # Bucket names that do not match the format are discarded.
            return True

    def validate_api_key(self, api_key: str) -> Optional[str]:
        """Look up *api_key* and record the access.

        On a match, ``last_used`` is set to the current time and
        ``request_count`` is incremented.

        Returns:
            The user name associated with the key, or ``None`` when the key
            is not registered.
        """
        record = self.api_keys.get(api_key)
        if record is None:
            return None
        record["last_used"] = datetime.now()
        record["request_count"] += 1
        return record["user"]

    def check_rate_limit(self, api_key: str) -> bool:
        """Count one request against *api_key* if it is within the limit.

        Buckets older than the retention window (or with unparseable names)
        are pruned first.  Note that a bucket dictionary is created for any
        *api_key* passed in, whether or not it is a registered key.

        Returns:
            ``True`` when the request was counted, ``False`` when the current
            minute's count has already reached ``max_requests_per_minute``.
        """
        # Take a single timestamp so the bucket name and the prune cutoff
        # cannot disagree across a minute boundary.
        now = datetime.now()
        current_minute = now.strftime(self._MINUTE_FORMAT)
        cutoff = now - timedelta(minutes=self._RETENTION_MINUTES)

        buckets = self.rate_limits.setdefault(api_key, {})
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        stale = [name for name in buckets if self._is_stale(name, cutoff)]
        for name in stale:
            del buckets[name]

        used = buckets.get(current_minute, 0)
        if used >= self.max_requests_per_minute:
            return False
        buckets[current_minute] = used + 1
        return True

    def get_api_key_stats(self, api_key: str) -> Optional[Dict]:
        """Return a snapshot of the statistics kept for *api_key*.

        Returns:
            A shallow copy of the key's record extended with
            ``"current_minute_requests"`` and ``"rate_limit"``, or ``None``
            when the key is not registered.
        """
        record = self.api_keys.get(api_key)
        if record is None:
            return None
        stats = record.copy()
        current_minute = datetime.now().strftime(self._MINUTE_FORMAT)
        stats["current_minute_requests"] = self.rate_limits.get(
            api_key, {}
        ).get(current_minute, 0)
        stats["rate_limit"] = self.max_requests_per_minute
        return stats

    def generate_new_api_key(self, user: str) -> str:
        """Create, register and return a new random API key for *user*."""
        api_key = secrets.token_urlsafe(32)
        self.api_keys[api_key] = self._new_record(user)
        return api_key

    def revoke_api_key(self, api_key: str) -> bool:
        """Remove *api_key* together with its rate-limit buckets.

        Returns:
            ``True`` when the key existed and was removed, ``False`` otherwise.
        """
        if api_key not in self.api_keys:
            return False
        del self.api_keys[api_key]
        self.rate_limits.pop(api_key, None)
        return True

    def list_api_keys(self) -> Dict:
        """List every key's stats without revealing the keys themselves.

        Each entry is identified by the first 8 hex characters of the key's
        SHA-256 digest.  Because the digest is truncated, two keys sharing
        that prefix would collapse into a single entry.
        """
        result = {}
        for api_key, info in self.api_keys.items():
            key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:8]
            last_used = info["last_used"]
            result[key_hash] = {
                "user": info["user"],
                "created": info["created"].isoformat(),
                "last_used": last_used.isoformat() if last_used else None,
                "request_count": info["request_count"],
            }
        return result
# Shared manager instance, constructed once when this module is imported.
api_key_manager: APIKeyManager = APIKeyManager()
def get_api_key_manager() -> APIKeyManager:
    """Return the module-level ``api_key_manager`` instance."""
    shared_manager = api_key_manager
    return shared_manager