import os
import logging
from typing import Optional
from datetime import datetime

from fastapi import FastAPI, HTTPException, Depends, Security, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM AI Agent API",
    description="Secure AI Agent API with Local LLM deployment",
    version="1.0.0"
)

# CORS middleware (allow_origins=["*"] accepts any origin; restrict in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
security = HTTPBearer()

# Configuration: demo fallback keys; override via API_KEY_1 / API_KEY_2 env vars
API_KEYS = {
    os.getenv("API_KEY_1", "27Eud5J73j6SqPQAT2ioV-CtiCg-p0WNqq6I4U0Ig6E"): "user1",
    os.getenv("API_KEY_2", "QbzG2CqHU1Nn6F1EogZ1d3dp8ilRTMJQBwTJDQBzS-U"): "user2",
}

# Global variables for model
model = None
tokenizer = None
model_loaded = False

# Request/Response models
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=1000)
    max_length: Optional[int] = Field(100, ge=10, le=500)
    temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0)

class ChatResponse(BaseModel):
    response: str
    model_used: str
    timestamp: str
    processing_time: float

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    timestamp: str
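
# Example /chat request body accepted by ChatRequest (values are illustrative;
# the allowed ranges come from the Field constraints above):
#   {"message": "hello", "max_length": 120, "temperature": 0.7}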

def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify API key authentication"""
    api_key = credentials.credentials
    if api_key not in API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return API_KEYS[api_key]
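
# Clients authenticate with a bearer token that must match a key in API_KEYS:
#   curl -H "Authorization: Bearer <your-api-key>" ...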

@app.on_event("startup")
async def load_model():
    """Load the LLM model on startup"""
    global model, tokenizer, model_loaded
    try:
        logger.info("Loading model...")
        # Try to import and load transformers
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            import torch

            model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
            logger.info(f"Loading model: {model_name}")

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Load model
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # Use float32 for CPU compatibility
                low_cpu_mem_usage=True
            )
            model_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.warning(f"Could not load transformers model: {e}")
            logger.info("Running in demo mode with simple responses")
            model_loaded = False
    except Exception as e:
        logger.error(f"Error during startup: {str(e)}")
        model_loaded = False

@app.get("/", response_model=HealthResponse)
async def root():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Detailed health check"""
    return HealthResponse(
        status="healthy" if model_loaded else "demo_mode",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Main chat endpoint for AI agent interaction"""
    start_time = datetime.now()
    try:
        if model_loaded and model is not None and tokenizer is not None:
            # Use actual model (a pipeline is built per request for simplicity;
            # caching it would avoid repeated setup cost)
            from transformers import pipeline
            generator = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=-1  # Use CPU
            )
            # Generate response
            generated = generator(
                request.message,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response_text = generated[0]['generated_text']
            # Strip the echoed prompt so only the completion is returned
            if request.message in response_text:
                response_text = response_text.replace(request.message, "").strip()
            model_used = os.getenv("MODEL_NAME", "microsoft/DialoGPT-small")
        else:
            # Demo mode - simple canned replies via naive substring matching
            demo_responses = {
                "hello": "Hello! I'm your AI assistant. How can I help you today?",
                "hi": "Hi there! I'm ready to assist you.",
                "how are you": "I'm doing well, thank you for asking! How can I help you?",
                "what is ai": "AI (Artificial Intelligence) is the simulation of human intelligence in machines that are programmed to think and learn.",
                "machine learning": "Machine learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed.",
                "default": "I'm an AI assistant ready to help you. Could you please rephrase your question?"
            }
            message_lower = request.message.lower()
            response_text = demo_responses["default"]
            for key, response in demo_responses.items():
                if key in message_lower:
                    response_text = response
                    break
            model_used = "demo_mode"

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        return ChatResponse(
            response=response_text,
            model_used=model_used,
            timestamp=datetime.now().isoformat(),
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )
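
# Example request (host/port assume the defaults in this file; replace
# <your-api-key> with a key from API_KEYS):
#   curl -X POST http://localhost:7860/chat \
#     -H "Authorization: Bearer <your-api-key>" \
#     -H "Content-Type: application/json" \
#     -d '{"message": "hello"}'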

@app.get("/model/info")  # route path assumed; adjust to match your clients
async def get_model_info(user: str = Depends(verify_api_key)):
    """Get information about the loaded model"""
    return {
        "model_name": os.getenv("MODEL_NAME", "microsoft/DialoGPT-small"),
        "model_loaded": model_loaded,
        "status": "loaded" if model_loaded else "demo_mode"
    }

if __name__ == "__main__":
    # For local development and Hugging Face Spaces
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        "app_simple:app",
        host="0.0.0.0",
        port=port,
        reload=False
    )
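
# Quick local smoke test (assumes this file is saved as app_simple.py, as the
# uvicorn target above implies; 7860 is the default port configured here):
#   python app_simple.py
#   curl http://localhost:7860/health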