from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import io
import numpy as np
from typing import List, Dict, Any
import logging
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ChatGPT Oasis Model Inference API",
    description="FastAPI inference server for Oasis and ViT models deployed on Hugging Face Spaces with Docker",
    version="1.0.0"
)

# Global variables to store loaded models
oasis_model = None
oasis_processor = None
vit_model = None
vit_processor = None


class InferenceRequest(BaseModel):
    image: str  # Base64 encoded image
    model_name: str = "oasis500m"  # Default to oasis model


class InferenceResponse(BaseModel):
    predictions: List[Dict[str, Any]]
    model_used: str
    confidence_scores: List[float]
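
# Example request body for POST /inference (illustrative sketch; the base64
# string is a placeholder for real encoded image bytes):
# {
#     "image": "<base64-encoded JPEG/PNG bytes>",
#     "model_name": "oasis500m"
# }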


def load_models():
    """Load both models from local files"""
    global oasis_model, oasis_processor, vit_model, vit_processor
    try:
        logger.info("Loading Oasis 500M model from local files...")
        # Load Oasis model from local files
        oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
        oasis_model = AutoModelForImageClassification.from_pretrained(
            "microsoft/oasis-500m",
            local_files_only=False  # Will download config but use local weights
        )

        # Load local weights if available
        oasis_model_path = "/app/models/oasis500m.safetensors"
        if os.path.exists(oasis_model_path):
            logger.info("Loading Oasis weights from local file...")
            from safetensors.torch import load_file
            state_dict = load_file(oasis_model_path)
            oasis_model.load_state_dict(state_dict, strict=False)

        oasis_model.eval()

        logger.info("Loading ViT-L-20 model from local files...")
        # Load ViT model from local files
        vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
        vit_model = AutoModelForImageClassification.from_pretrained(
            "google/vit-large-patch16-224",
            local_files_only=False  # Will download config but use local weights
        )

        # Load local weights if available
        vit_model_path = "/app/models/vit-l-20.safetensors"
        if os.path.exists(vit_model_path):
            logger.info("Loading ViT weights from local file...")
            from safetensors.torch import load_file
            state_dict = load_file(vit_model_path)
            vit_model.load_state_dict(state_dict, strict=False)

        vit_model.eval()

        logger.info("All models loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise e


@app.on_event("startup")
async def startup_event():
    """Load models when the application starts"""
    load_models()


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "ChatGPT Oasis Model Inference API",
        "version": "1.0.0",
        "deployed_on": "Hugging Face Spaces (Docker)",
        "available_models": ["oasis500m", "vit-l-20"],
        "endpoints": {
            "health": "/health",
            "inference": "/inference",
            "upload_inference": "/upload_inference",
            "predict": "/predict"
        },
        "usage": {
            "base64_inference": "POST /inference with JSON body containing 'image' (base64) and 'model_name'",
            "file_upload": "POST /upload_inference with multipart form containing 'file' and optional 'model_name'",
            "simple_predict": "POST /predict with file upload for quick inference"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    models_status = {
        "oasis500m": oasis_model is not None,
        "vit-l-20": vit_model is not None
    }

    # Check if model files exist
    model_files = {
        "oasis500m": os.path.exists("/app/models/oasis500m.safetensors"),
        "vit-l-20": os.path.exists("/app/models/vit-l-20.safetensors")
    }

    return {
        "status": "healthy",
        "models_loaded": models_status,
        "model_files_present": model_files,
        "deployment": "huggingface-spaces-docker"
    }


def process_image_with_model(image: Image.Image, model_name: str):
    """Process image with the specified model"""
    if model_name == "oasis500m":
        if oasis_model is None or oasis_processor is None:
            raise HTTPException(status_code=500, detail="Oasis model not loaded")

        inputs = oasis_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = oasis_model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=-1)

        # Get top predictions
        top_probs, top_indices = torch.topk(probabilities, 5)
        predictions = []
        for i in range(top_indices.shape[1]):
            pred = {
                "label": oasis_model.config.id2label[top_indices[0][i].item()],
                "confidence": top_probs[0][i].item()
            }
            predictions.append(pred)

        return predictions

    elif model_name == "vit-l-20":
        if vit_model is None or vit_processor is None:
            raise HTTPException(status_code=500, detail="ViT model not loaded")

        inputs = vit_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = vit_model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=-1)

        # Get top predictions
        top_probs, top_indices = torch.topk(probabilities, 5)
        predictions = []
        for i in range(top_indices.shape[1]):
            pred = {
                "label": vit_model.config.id2label[top_indices[0][i].item()],
                "confidence": top_probs[0][i].item()
            }
            predictions.append(pred)

        return predictions

    else:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")


@app.post("/inference")
async def inference(request: InferenceRequest):
    """Inference endpoint using base64 encoded image"""
    try:
        import base64

        # Decode base64 image
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Process with model
        predictions = process_image_with_model(image, request.model_name)

        # Extract confidence scores
        confidence_scores = [pred["confidence"] for pred in predictions]

        return InferenceResponse(
            predictions=predictions,
            model_used=request.model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. unknown model) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/upload_inference")
async def upload_inference(
    file: UploadFile = File(...),
    model_name: str = Form("oasis500m")  # read from the multipart form, matching the documented usage
):
    """Inference endpoint using file upload"""
    try:
        # Validate file type
        if not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and process image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Process with model
        predictions = process_image_with_model(image, model_name)

        # Extract confidence scores
        confidence_scores = [pred["confidence"] for pred in predictions]

        return InferenceResponse(
            predictions=predictions,
            model_used=model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Re-raise client errors (bad file type, unknown model) unchanged
        raise
    except Exception as e:
        logger.error(f"Upload inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
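
# Example client call for the upload endpoint above (a sketch, not part of the
# server; the Space URL is a placeholder and `requests` is assumed to be
# available on the client side):
#
#   import requests
#   resp = requests.post(
#       "https://<your-space>.hf.space/upload_inference",
#       files={"file": ("cat.jpg", open("cat.jpg", "rb"), "image/jpeg")},
#       data={"model_name": "vit-l-20"},
#   )
#   print(resp.json())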


@app.get("/models")
async def list_models():
    """List available models and their status"""
    return {
        "available_models": [
            {
                "name": "oasis500m",
                "description": "Oasis 500M vision model",
                "loaded": oasis_model is not None,
                "file_present": os.path.exists("/app/models/oasis500m.safetensors")
            },
            {
                "name": "vit-l-20",
                "description": "Vision Transformer Large model",
                "loaded": vit_model is not None,
                "file_present": os.path.exists("/app/models/vit-l-20.safetensors")
            }
        ]
    }


# Hugging Face Spaces specific endpoint for Gradio compatibility
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Simple prediction endpoint for Hugging Face Spaces integration"""
    try:
        # Validate file type
        if not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and process image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Process with default model (oasis500m)
        predictions = process_image_with_model(image, "oasis500m")

        # Return simplified format for Gradio
        return {
            "predictions": predictions[:3],  # Top 3 predictions
            "model_used": "oasis500m"
        }
    except HTTPException:
        # Re-raise client errors unchanged
        raise
    except Exception as e:
        logger.error(f"Predict error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
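

# Entrypoint sketch: this block is an assumption, not shown in the original
# file. On a Docker-based Hugging Face Space the Dockerfile usually starts the
# server itself (e.g. `CMD ["uvicorn", "app:app", "--host", "0.0.0.0",
# "--port", "7860"]`, assuming this file is named app.py); if it does not,
# the block below launches uvicorn directly. Port 7860 is the default port a
# Docker Space is expected to listen on unless app_port is overridden in the
# Space's README.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)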