"""FastAPI application entry point for the Super Resolution Dental X-ray API."""
import os
import subprocess
import sys

import torch
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from src.app.exceptions import (
    InferenceError,
    InputError,
    ModelLoadError,
    PostprocessingError,
    PreprocessingError,
)
from src.app.routes import inference
app = FastAPI(title="Super Resolution Dental X-ray API", version="1.0.0")

# CORS: allow any origin. NOTE(review): with only allow_origins set, the
# middleware defaults to allow_methods=["GET"], which would reject the POST
# requests an inference endpoint needs — so methods/headers are opened too.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include API routes
app.include_router(inference.router, prefix="/inference", tags=["Inference"])
@app.get("/")
def read_root():
    """Root endpoint: return a welcome message confirming the API is up.

    Returns:
        dict: A static welcome message.
    """
    # Without the @app.get registration this function is dead code —
    # presumably the decorator was lost; restored here.
    return {"message": "Welcome to the Super Resolution Dental X-ray API"}
@app.get("/health")
async def health_check():
    """
    Health check endpoint to ensure the API and CUDA are running.

    Returns:
        dict: Status message plus Python/PyTorch/CUDA environment details,
        including raw ``nvidia-smi`` output for GPU diagnostics.
    """

    def _run(command: list) -> str:
        # os.popen would invoke a shell and silently return "" on failure;
        # subprocess.run with a list avoids the shell and lets us report a
        # missing binary (e.g. nvidia-smi on a CPU-only host) explicitly.
        try:
            result = subprocess.run(
                command, capture_output=True, text=True, check=False
            )
        except FileNotFoundError:
            return f"{command[0]}: command not found"
        return result.stdout

    # Construct response with CUDA/cuDNN status
    return {
        "status": "Healthy",
        "message": "API is running successfully.",
        "cuda": {
            "sys.version": sys.version,
            "torch.__version__": torch.__version__,
            "torch.cuda.is_available()": torch.cuda.is_available(),
            "torch.version.cuda": torch.version.cuda,
            "torch.backends.cudnn.version()": torch.backends.cudnn.version(),
            "torch.backends.cudnn.enabled": torch.backends.cudnn.enabled,
            "nvidia-smi": _run(["nvidia-smi"]),
        },
    }
| # Custom exception handlers | |
@app.exception_handler(ModelLoadError)
async def model_load_error_handler(request: Request, exc: ModelLoadError):
    """Map ModelLoadError to a 500 JSON response.

    Without @app.exception_handler registration this handler is never
    invoked — presumably the decorator was lost; restored here.
    """
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "ModelLoadError", "message": exc.message},
    )
@app.exception_handler(PreprocessingError)
async def preprocessing_error_handler(request: Request, exc: PreprocessingError):
    """Map PreprocessingError to a 500 JSON response.

    Registered via @app.exception_handler so FastAPI actually dispatches
    the exception here (the decorator was missing).
    """
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "PreprocessingError", "message": exc.message},
    )
@app.exception_handler(PostprocessingError)
async def postprocessing_error_handler(request: Request, exc: PostprocessingError):
    """Map PostprocessingError to a 500 JSON response.

    Registered via @app.exception_handler so FastAPI actually dispatches
    the exception here (the decorator was missing).
    """
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "PostprocessingError", "message": exc.message},
    )
@app.exception_handler(InferenceError)
async def inference_error_handler(request: Request, exc: InferenceError):
    """Map InferenceError to a 500 JSON response.

    Registered via @app.exception_handler so FastAPI actually dispatches
    the exception here (the decorator was missing).
    """
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "InferenceError", "message": exc.message},
    )
@app.exception_handler(InputError)
async def input_load_error_handler(request: Request, exc: InputError):
    """Map InputError to a 400 JSON response (client-side input problem).

    Registered via @app.exception_handler so FastAPI actually dispatches
    the exception here (the decorator was missing).
    """
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"error": "InputError", "message": exc.message},
    )