from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import base64
import io
import numpy as np
from typing import List, Dict, Any
import logging
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ChatGPT Oasis Model Inference API",
    description="FastAPI inference server for Oasis and ViT models deployed on Hugging Face Spaces with Docker",
    version="1.0.0"
)

# Global variables to store loaded models
oasis_model = None
oasis_processor = None
vit_model = None
vit_processor = None


class InferenceRequest(BaseModel):
    image: str  # Base64 encoded image
    model_name: str = "oasis500m"  # Default to the Oasis model


class InferenceResponse(BaseModel):
    predictions: List[Dict[str, Any]]
    model_used: str
    confidence_scores: List[float]


def load_models():
    """Load both models from local files"""
    global oasis_model, oasis_processor, vit_model, vit_processor

    try:
        logger.info("Loading Oasis 500M model from local files...")

        # Load Oasis model; config (and base weights) come from the Hub,
        # local weights are applied below if present
        oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
        oasis_model = AutoModelForImageClassification.from_pretrained(
            "microsoft/oasis-500m",
            local_files_only=False
        )

        # Load local weights if available
        oasis_model_path = "/app/models/oasis500m.safetensors"
        if os.path.exists(oasis_model_path):
            logger.info("Loading Oasis weights from local file...")
            from safetensors.torch import load_file
            state_dict = load_file(oasis_model_path)
            oasis_model.load_state_dict(state_dict, strict=False)

        oasis_model.eval()

        logger.info("Loading ViT-L-20 model from local files...")

        # Load ViT model; config (and base weights) come from the Hub,
        # local weights are applied below if present
        vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
        vit_model = AutoModelForImageClassification.from_pretrained(
            "google/vit-large-patch16-224",
            local_files_only=False
        )

        # Load local weights if available
        vit_model_path = "/app/models/vit-l-20.safetensors"
        if os.path.exists(vit_model_path):
            logger.info("Loading ViT weights from local file...")
            from safetensors.torch import load_file
            state_dict = load_file(vit_model_path)
            vit_model.load_state_dict(state_dict, strict=False)

        vit_model.eval()

        logger.info("All models loaded successfully!")

    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise e


@app.on_event("startup")
async def startup_event():
    """Load models when the application starts"""
    load_models()


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "ChatGPT Oasis Model Inference API",
        "version": "1.0.0",
        "deployed_on": "Hugging Face Spaces (Docker)",
        "available_models": ["oasis500m", "vit-l-20"],
        "endpoints": {
            "health": "/health",
            "inference": "/inference",
            "upload_inference": "/upload_inference",
            "predict": "/predict"
        },
        "usage": {
            "base64_inference": "POST /inference with JSON body containing 'image' (base64) and 'model_name'",
            "file_upload": "POST /upload_inference with multipart form containing 'file' and optional 'model_name'",
            "simple_predict": "POST /predict with file upload for quick inference"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    models_status = {
        "oasis500m": oasis_model is not None,
        "vit-l-20": vit_model is not None
    }

    # Check if model files exist
    model_files = {
        "oasis500m": os.path.exists("/app/models/oasis500m.safetensors"),
        "vit-l-20": os.path.exists("/app/models/vit-l-20.safetensors")
    }

    return {
        "status": "healthy",
        "models_loaded": models_status,
        "model_files_present": model_files,
        "deployment": "huggingface-spaces-docker"
    }


def process_image_with_model(image: Image.Image, model_name: str):
    """Process image with the specified model"""
    if model_name == "oasis500m":
        if oasis_model is None or oasis_processor is None:
            raise HTTPException(status_code=500, detail="Oasis model not loaded")

        inputs = oasis_processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = oasis_model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=-1)

        # Get top predictions
        top_probs, top_indices = torch.topk(probabilities, 5)

        predictions = []
        for i in range(top_indices.shape[1]):
            pred = {
                "label": oasis_model.config.id2label[top_indices[0][i].item()],
                "confidence": top_probs[0][i].item()
            }
            predictions.append(pred)

        return predictions

    elif model_name == "vit-l-20":
        if vit_model is None or vit_processor is None:
            raise HTTPException(status_code=500, detail="ViT model not loaded")

        inputs = vit_processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = vit_model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=-1)

        # Get top predictions
        top_probs, top_indices = torch.topk(probabilities, 5)

        predictions = []
        for i in range(top_indices.shape[1]):
            pred = {
                "label": vit_model.config.id2label[top_indices[0][i].item()],
                "confidence": top_probs[0][i].item()
            }
            predictions.append(pred)

        return predictions

    else:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")


@app.post("/inference", response_model=InferenceResponse)
async def inference(request: InferenceRequest):
    """Inference endpoint using a base64 encoded image"""
    try:
        # Decode base64 image
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Process with model
        predictions = process_image_with_model(image, request.model_name)

        # Extract confidence scores
        confidence_scores = [pred["confidence"] for pred in predictions]

        return InferenceResponse(
            predictions=predictions,
            model_used=request.model_name,
            confidence_scores=confidence_scores
        )

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. unknown model) instead of masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/upload_inference", response_model=InferenceResponse)
async def upload_inference(
    file: UploadFile = File(...),
    model_name: str = "oasis500m"
):
    """Inference endpoint using file upload"""
    try:
        # Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and process image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Process with model
        predictions = process_image_with_model(image, model_name)

        # Extract confidence scores
        confidence_scores = [pred["confidence"] for pred in predictions]

        return InferenceResponse(
            predictions=predictions,
            model_used=model_name,
            confidence_scores=confidence_scores
        )

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. bad file type) instead of masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Upload inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/models")
async def list_models():
    """List available models and their status"""
    return {
        "available_models": [
            {
                "name": "oasis500m",
                "description": "Oasis 500M vision model",
                "loaded": oasis_model is not None,
                "file_present": os.path.exists("/app/models/oasis500m.safetensors")
            },
            {
                "name": "vit-l-20",
                "description": "Vision Transformer Large model",
                "loaded": vit_model is not None,
                "file_present": os.path.exists("/app/models/vit-l-20.safetensors")
            }
        ]
    }


# Hugging Face Spaces specific endpoint for Gradio compatibility
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Simple prediction endpoint for Hugging Face Spaces integration"""
    try:
        # Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and process image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Process with default model (oasis500m)
        predictions = process_image_with_model(image, "oasis500m")

        # Return simplified format for Gradio
        return {
            "predictions": predictions[:3],  # Top 3 predictions
            "model_used": "oasis500m"
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. bad file type) instead of masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Predict error: {e}")
        raise HTTPException(status_code=500, detail=str(e))