from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import base64
import io
from typing import List, Dict, Any
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ChatGPT Oasis Model Inference API",
    description="FastAPI inference server for Oasis and ViT models",
    version="1.0.0"
)

# Global variables holding the loaded models and processors
oasis_model = None
oasis_processor = None
vit_model = None
vit_processor = None
class InferenceRequest(BaseModel):
    image: str  # Base64-encoded image
    model_name: str = "oasis500m"  # Default to the Oasis model

class InferenceResponse(BaseModel):
    predictions: List[Dict[str, Any]]
    model_used: str
    confidence_scores: List[float]
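# For illustration, a request to POST /inference carries a JSON body shaped
# like this (the values below are placeholders, not real data):
#
#   {
#       "image": "<base64-encoded JPEG/PNG bytes>",
#       "model_name": "oasis500m"
#   }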
def load_models():
    """Load both models into memory"""
    global oasis_model, oasis_processor, vit_model, vit_processor
    try:
        logger.info("Loading Oasis 500M model...")
        # Load the Oasis model (weights are downloaded and cached on first use)
        oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
        oasis_model = AutoModelForImageClassification.from_pretrained("microsoft/oasis-500m")
        oasis_model.eval()

        logger.info("Loading ViT-L-20 model...")
        # Load the ViT model
        vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
        vit_model = AutoModelForImageClassification.from_pretrained("google/vit-large-patch16-224")
        vit_model.eval()

        logger.info("All models loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
@app.on_event("startup")
async def startup_event():
    """Load models when the application starts"""
    load_models()
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "ChatGPT Oasis Model Inference API",
        "version": "1.0.0",
        "available_models": ["oasis500m", "vit-l-20"],
        "endpoints": {
            "health": "/health",
            "inference": "/inference",
            "upload_inference": "/upload_inference"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    models_status = {
        "oasis500m": oasis_model is not None,
        "vit-l-20": vit_model is not None
    }
    return {
        "status": "healthy",
        "models_loaded": models_status
    }
def process_image_with_model(image: Image.Image, model_name: str) -> List[Dict[str, Any]]:
    """Run the specified model on an image and return the top-5 predictions"""
    if model_name == "oasis500m":
        model, processor = oasis_model, oasis_processor
    elif model_name == "vit-l-20":
        model, processor = vit_model, vit_processor
    else:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail=f"Model {model_name} not loaded")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1)

    # Get the top-5 predictions
    top_probs, top_indices = torch.topk(probabilities, 5)
    predictions = []
    for i in range(top_indices.shape[1]):
        predictions.append({
            "label": model.config.id2label[top_indices[0][i].item()],
            "confidence": top_probs[0][i].item()
        })
    return predictions
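# The return value is a list of label/confidence pairs, e.g. (illustrative
# values only):
#   [{"label": "golden retriever", "confidence": 0.92}, ...]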
@app.post("/inference")
async def inference(request: InferenceRequest):
    """Inference endpoint using a base64-encoded image"""
    try:
        # Decode the base64 image
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        # Run the requested model
        predictions = process_image_with_model(image, request.model_name)
        # Extract the confidence scores
        confidence_scores = [pred["confidence"] for pred in predictions]
        return InferenceResponse(
            predictions=predictions,
            model_used=request.model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. unknown model) through unchanged
        raise
    except Exception as e:
        logger.error(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/upload_inference")
async def upload_inference(
    file: UploadFile = File(...),
    model_name: str = "oasis500m"
):
    """Inference endpoint using file upload"""
    try:
        # Validate the file type (content_type may be None)
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Read and decode the image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        # Run the requested model
        predictions = process_image_with_model(image, model_name)
        # Extract the confidence scores
        confidence_scores = [pred["confidence"] for pred in predictions]
        return InferenceResponse(
            predictions=predictions,
            model_used=model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. bad file type) through unchanged
        raise
    except Exception as e:
        logger.error(f"Upload inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")  # route path assumed; the original decorator was lost in extraction
async def list_models():
    """List available models and their status"""
    return {
        "available_models": [
            {
                "name": "oasis500m",
                "description": "Oasis 500M vision model",
                "loaded": oasis_model is not None
            },
            {
                "name": "vit-l-20",
                "description": "Vision Transformer Large model",
                "loaded": vit_model is not None
            }
        ]
    }
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
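For reference, a minimal client sketch for the two inference endpoints. It assumes the server above is running on localhost:8000 and that the `requests` package is installed; the image path is a placeholder.

import base64
import requests

# JSON endpoint: send the image as a base64 string in the request body.
with open("example.jpg", "rb") as f:  # placeholder path
    payload = {
        "image": base64.b64encode(f.read()).decode("utf-8"),
        "model_name": "oasis500m",
    }
print(requests.post("http://localhost:8000/inference", json=payload).json())

# Upload endpoint: send the raw file as multipart/form-data;
# model_name arrives as a query parameter here.
with open("example.jpg", "rb") as f:
    print(requests.post(
        "http://localhost:8000/upload_inference",
        files={"file": ("example.jpg", f, "image/jpeg")},
        params={"model_name": "vit-l-20"},
    ).json())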