# ==============================
# AgriCopilot AI Backend (FastAPI app, Hugging Face Space)
# ==============================
| import os | |
| import logging | |
| import io | |
| from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from transformers import pipeline | |
| from PIL import Image | |
| from smebuilder_vector import query_vector | |
# ==============================
# Logging Setup
# ==============================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AgriCopilot")

# ==============================
# App Initialization
# ==============================
app = FastAPI(
    title="AgriCopilot AI API",
    version="2.0",
)
@app.get("/")
async def root():
    """Health-check endpoint confirming the backend process is up.

    Fix: the handler was defined but never registered with FastAPI (no
    decorator present), so it was dead code. "/" is the assumed path —
    NOTE(review): confirm against the deployment's expected routes.
    """
    return {"status": "AgriCopilot AI Backend is running smoothly ✅"}
| # ============================== | |
| # AUTH CONFIGURATION | |
| # ============================== | |
| PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "agricopilot404") | |
| def check_auth(authorization: str | None): | |
| if not PROJECT_API_KEY: | |
| return | |
| if not authorization or not authorization.startswith("Bearer "): | |
| raise HTTPException(status_code=401, detail="Missing bearer token") | |
| token = authorization.split(" ", 1)[1] | |
| if token != PROJECT_API_KEY: | |
| raise HTTPException(status_code=403, detail="Invalid token") | |
# ==============================
# Exception Handling
# ==============================
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the failure and return a JSON 500 response.

    Fixes: the handler was never registered (missing
    ``@app.exception_handler``), so it was dead code; logging now passes
    lazy %-args and ``exc_info`` so the full traceback is recorded instead
    of only the message text.
    """
    logger.error("Unhandled error: %s", exc, exc_info=exc)
    return JSONResponse(status_code=500, content={"error": str(exc)})
# ==============================
# Request Models
# ==============================
class ChatRequest(BaseModel):
    """Request body for the multilingual chat endpoint."""

    query: str
class DisasterRequest(BaseModel):
    """Request body for the disaster-report summarizer endpoint."""

    report: str
class MarketRequest(BaseModel):
    """Request body for the marketplace recommendation endpoint."""

    product: str
class VectorRequest(BaseModel):
    """Request body for the vector-store search endpoint."""

    query: str
# ==============================
# Load Hugging Face Pipelines
# ==============================
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF_TOKEN:
    logger.warning("⚠️ No Hugging Face token found. Gated models may fail.")
else:
    logger.info("✅ Hugging Face token loaded successfully.")

# General text-generation model for chat, disaster, market endpoints
default_model = "meta-llama/Llama-3.1-8B-Instruct"
vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Fix: the chat, disaster and market endpoints all use the identical model,
# so load the text-generation pipeline ONCE and alias it, instead of holding
# three independent copies of an 8B-parameter model in memory. The public
# names (chat_pipe / disaster_pipe / market_pipe) are unchanged for callers.
_text_pipe = pipeline("text-generation", model=default_model, token=HF_TOKEN)
chat_pipe = _text_pipe
disaster_pipe = _text_pipe
market_pipe = _text_pipe

# Multimodal crop diagnostic model — optional: crop_pipe is set to None on
# failure and the crop-doctor endpoint degrades gracefully.
try:
    crop_pipe = pipeline("image-text-to-text", model=vision_model, token=HF_TOKEN)
except Exception as e:
    logger.warning(f"Crop model load failed: {e}")
    crop_pipe = None
# ==============================
# Helper Functions
# ==============================
def run_conversational(pipe, prompt: str, max_new_tokens: int = 200) -> str:
    """Run a text-generation pipeline on ``prompt`` and return its text.

    Args:
        pipe: a transformers text-generation pipeline (or compatible callable
            accepting ``(prompt, max_new_tokens=...)``).
        prompt: the user prompt, forwarded verbatim to the model.
        max_new_tokens: generation budget. Generalized from the previously
            hard-coded 200; the default preserves old behavior.

    Returns:
        The generated text, or a "⚠️ Model error" string on failure —
        callers treat errors as content rather than exceptions.
    """
    try:
        output = pipe(prompt, max_new_tokens=max_new_tokens)
        # transformers returns a list of dicts; fall back to str() for any
        # unexpected output shape instead of raising.
        if isinstance(output, list) and len(output) > 0:
            return output[0].get("generated_text", str(output))
        return str(output)
    except Exception as e:
        logger.error(f"Pipeline error: {e}")
        return f"⚠️ Model error: {str(e)}"
def run_crop_doctor(image_bytes: bytes, symptoms: str):
    """Diagnose crop issues from an image plus reported symptoms using the
    multimodal LLaMA Vision pipeline; always returns a human-readable string
    (errors are returned in-band, never raised to the caller)."""
    if not crop_pipe:
        # Vision model failed to load at startup — degrade gracefully.
        return "⚠️ Crop analysis temporarily unavailable (model not loaded)."
    try:
        picture = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        instructions = (
            f"The farmer reports: {symptoms}. "
            "Analyze the image and diagnose the likely crop disease. "
            "Then explain it simply and recommend possible treatment steps."
        )
        result = crop_pipe(picture, instructions)
        if isinstance(result, list) and len(result) > 0:
            return result[0].get("generated_text", str(result))
        return str(result)
    except Exception as e:
        logger.error(f"Crop Doctor pipeline error: {e}")
        return f"⚠️ Unexpected model error: {str(e)}"
# ==============================
# API ROUTES
# ==============================
# NOTE(review): no @app.post(...) decorators are visible on the handlers
# below, so they are never registered with FastAPI — confirm whether the
# decorators were lost in extraction and restore the intended route paths.
async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
    """Chat endpoint: authenticate, then generate a reply for the query."""
    check_auth(authorization)
    answer = run_conversational(chat_pipe, req.query)
    return {"reply": answer}
async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
    """Summarize a free-text disaster report (bearer-token protected).

    NOTE(review): no route decorator visible — confirm the intended path.
    """
    check_auth(authorization)
    condensed = run_conversational(disaster_pipe, req.report)
    return {"summary": condensed}
async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
    """Produce a marketplace recommendation for the given product.

    NOTE(review): no route decorator visible — confirm the intended path.
    """
    check_auth(authorization)
    advice = run_conversational(market_pipe, req.product)
    return {"recommendation": advice}
async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
    """Semantic search against the project vector store.

    Errors are returned in-band as {"error": ...} rather than raised.
    NOTE(review): no route decorator visible — confirm the intended path.
    """
    check_auth(authorization)
    try:
        matches = query_vector(req.query)
    except Exception as e:
        logger.error(f"Vector search error: {e}")
        return {"error": f"Vector search error: {str(e)}"}
    return {"results": matches}
async def crop_doctor(
    symptoms: str = Header(...),
    image: UploadFile = File(...),
    authorization: str | None = Header(None)
):
    """Diagnose a crop from an uploaded photo plus a `symptoms` HTTP header.

    NOTE(review): `symptoms` arrives as a header; a Form(...) field would be
    more conventional for a multipart upload — confirm with clients before
    changing, since that would alter the request contract. Also no route
    decorator is visible — confirm the intended path.
    """
    check_auth(authorization)
    payload = await image.read()
    verdict = run_crop_doctor(payload, symptoms)
    return {"diagnosis": verdict}