import os
import secrets
import hashlib
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import logging

from fastapi import FastAPI, HTTPException, Depends, Security, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM AI Agent API",
    description="Secure AI Agent API with Local LLM deployment",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS middleware for cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
security = HTTPBearer()

# Configuration
class Config:
    # API keys - in production, load these from environment variables or a secret store
    API_KEYS = {
        os.getenv("API_KEY_1", "your-secure-api-key-1"): "user1",
        os.getenv("API_KEY_2", "your-secure-api-key-2"): "user2",
        # Add more API keys as needed
    }

    # Model configuration
    MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")  # Lightweight model for free tier
    MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
    TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
    TOP_P = float(os.getenv("TOP_P", "0.9"))

    # Rate limiting (requests per minute per API key)
    RATE_LIMIT = int(os.getenv("RATE_LIMIT", "10"))

# Global variables for model and tokenizer
model = None
tokenizer = None
text_generator = None

# Request/response models
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=1000, description="Input message for the AI agent")
    max_length: Optional[int] = Field(None, ge=10, le=2048, description="Maximum response length")
    temperature: Optional[float] = Field(None, ge=0.1, le=2.0, description="Response creativity (0.1-2.0)")
    system_prompt: Optional[str] = Field(None, max_length=500, description="Optional system prompt")


class ChatResponse(BaseModel):
    response: str
    model_used: str
    timestamp: str
    tokens_used: int
    processing_time: float


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    timestamp: str
    version: str


# Rate limiting storage (in production, use Redis)
request_counts: Dict[str, Dict[str, int]] = {}

def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify API key authentication."""
    api_key = credentials.credentials
    if api_key not in Config.API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Config.API_KEYS[api_key]


def check_rate_limit(api_key: str) -> bool:
    """Simple in-memory rate limiting, counted per caller and per minute."""
    current_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
    if api_key not in request_counts:
        request_counts[api_key] = {}
    if current_minute not in request_counts[api_key]:
        request_counts[api_key][current_minute] = 0
    if request_counts[api_key][current_minute] >= Config.RATE_LIMIT:
        return False
    request_counts[api_key][current_minute] += 1
    return True

@app.on_event("startup")
async def load_model():
    """Load the LLM model on startup."""
    global model, tokenizer, text_generator
    try:
        logger.info(f"Loading model: {Config.MODEL_NAME}")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)

        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations for the free tier
        model = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_NAME,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True
        )

        # Create text generation pipeline.
        # When device_map="auto" was used, accelerate has already placed the model,
        # so an explicit device is only passed for the CPU case.
        text_generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=None if torch.cuda.is_available() else -1
        )

        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise e

@app.get("/", response_model=HealthResponse)
async def root():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        model_loaded=model is not None,
        timestamp=datetime.now().isoformat(),
        version="1.0.0"
    )

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Detailed health check."""
    return HealthResponse(
        status="healthy" if model is not None else "model_not_loaded",
        model_loaded=model is not None,
        timestamp=datetime.now().isoformat(),
        version="1.0.0"
    )

@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Main chat endpoint for AI agent interaction."""
    start_time = datetime.now()

    # Rate limit per caller: the user name returned by verify_api_key maps
    # one-to-one to an API key, so it serves as the rate-limiting identity.
    if not check_rate_limit(user):
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded. Please try again later."
        )

    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please try again later."
        )

    try:
        # Prepare input
        input_text = request.message
        if request.system_prompt:
            input_text = f"System: {request.system_prompt}\nUser: {request.message}\nAssistant:"

        # Generation parameters
        max_length = request.max_length or Config.MAX_LENGTH
        temperature = request.temperature or Config.TEMPERATURE

        # Generate text
        generated = text_generator(
            input_text,
            max_length=max_length,
            temperature=temperature,
            top_p=Config.TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            truncation=True
        )

        # Extract response (drop the echoed prompt if it is included)
        response_text = generated[0]['generated_text']
        if input_text in response_text:
            response_text = response_text.replace(input_text, "").strip()

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        # Count tokens (approximate)
        tokens_used = len(tokenizer.encode(response_text))

        return ChatResponse(
            response=response_text,
            model_used=Config.MODEL_NAME,
            timestamp=datetime.now().isoformat(),
            tokens_used=tokens_used,
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )

@app.get("/model/info")
async def get_model_info(user: str = Depends(verify_api_key)):
    """Get information about the loaded model."""
    return {
        "model_name": Config.MODEL_NAME,
        "model_loaded": model is not None,
        "max_length": Config.MAX_LENGTH,
        "temperature": Config.TEMPERATURE,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }

if __name__ == "__main__":
    # For local development
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),  # Hugging Face Spaces uses port 7860
        reload=False
    )
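
Once the service is running, it can be exercised with any HTTP client. The snippet below is a minimal sketch of a client call, assuming the app is reachable at http://localhost:7860, that the /chat route shown above is in place, and that the API key matches an entry in Config.API_KEYS; the base URL and key are placeholders to adjust for your deployment.

import requests

BASE_URL = "http://localhost:7860"   # placeholder: your Space or local URL
API_KEY = "your-secure-api-key-1"    # placeholder: must match an entry in Config.API_KEYS

# Send a chat request with the bearer token expected by verify_api_key
resp = requests.post(
    f"{BASE_URL}/chat",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "message": "Summarize what this API does in one sentence.",
        "temperature": 0.7,
    },
    timeout=120,
)
resp.raise_for_status()
data = resp.json()
print(data["response"])
print(f"{data['tokens_used']} tokens in {data['processing_time']:.2f}s")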