DevAssist / main.py
lydiasolomon's picture
Update main.py
f70d966 verified
raw
history blame
5.61 kB
import os
import logging
import io
from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline
from PIL import Image
from smebuilder_vector import query_vector
# ==============================
# Logging Setup
# ==============================
# Configure the root logger at INFO (basicConfig is a no-op if the root logger
# already has handlers), then grab the named logger used throughout this module.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name="AgriCopilot")
# ==============================
# App Initialization
# ==============================
# The single FastAPI application; every route and exception handler in this
# module is registered on it via the @app decorators.
app = FastAPI(
    version="2.0",
    title="AgriCopilot AI API",
)


@app.get("/")
async def root():
    # Unauthenticated liveness check: always returns the same fixed payload.
    status_payload = {"status": "AgriCopilot AI Backend is running smoothly ✅"}
    return status_payload
# ==============================
# AUTH CONFIGURATION
# ==============================
# NOTE(review): when PROJECT_API_KEY is unset this falls back to a key that is
# hard-coded in source, so anyone who can read this file can authenticate.
# Set a private value in every deployment; the warning below makes the
# fallback visible in the logs. Setting the variable to "" disables auth.
_DEFAULT_PROJECT_API_KEY = "agricopilot404"
PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", _DEFAULT_PROJECT_API_KEY)
if PROJECT_API_KEY == _DEFAULT_PROJECT_API_KEY:
    logger.warning(
        "⚠️ PROJECT_API_KEY is the built-in default key; set a private value in production."
    )


def check_auth(authorization: str | None) -> None:
    """Validate an ``Authorization: Bearer <token>`` header against PROJECT_API_KEY.

    Auth is skipped entirely when PROJECT_API_KEY is empty.

    Args:
        authorization: Raw ``Authorization`` header value, or None if absent.

    Raises:
        HTTPException: 401 if the header is missing, is not a Bearer header,
            or carries an empty token; 403 if the token does not match.
    """
    import hmac  # local import keeps this block self-contained

    if not PROJECT_API_KEY:
        return

    scheme, _, token = (authorization or "").partition(" ")
    token = token.strip()
    # Auth scheme names are case-insensitive (RFC 7235), so accept "bearer" too.
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing bearer token")

    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(token.encode("utf-8"), PROJECT_API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid token")
# ==============================
# Exception Handling
# ==============================
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Convert any unhandled exception into a JSON 500 response.

    The full traceback is logged server-side. The client receives only a
    generic message under the same ``"error"`` key, so internal details
    (file paths, model/config errors) are not leaked in responses.
    """
    logger.error(
        "Unhandled error on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,  # passing the instance works even outside an except block
    )
    return JSONResponse(status_code=500, content={"error": "Internal server error"})
# ==============================
# Request Models
# ==============================
# JSON request bodies for the text endpoints. Field names are the public API
# schema. Comments are used instead of docstrings here because pydantic
# copies class docstrings into the generated schema.


# Body for POST /multilingual-chat: free-form user question.
class ChatRequest(BaseModel):
    query: str


# Body for POST /disaster-summarizer: the report text sent to the model.
class DisasterRequest(BaseModel):
    report: str


# Body for POST /marketplace: product text sent to the model.
class MarketRequest(BaseModel):
    product: str


# Body for POST /vector-search: text passed to query_vector.
class VectorRequest(BaseModel):
    query: str
# ==============================
# Load Hugging Face Pipelines
# ==============================
# Hub token passed to every pipeline() call below. The warning text assumes
# some models are gated — confirm access for the chosen repos.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF_TOKEN:
    logger.warning("⚠️ No Hugging Face token found. Gated models may fail.")
else:
    logger.info("✅ Hugging Face token loaded successfully.")

# General text-generation model for chat, disaster, market endpoints
default_model = "meta-llama/Llama-3.1-8B-Instruct"
vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# The chat, disaster and market endpoints use the same model with identical
# settings, so load it ONCE and share the pipeline object. Building three
# pipelines loaded three separate copies of the same weights into memory.
# All three names are kept because the endpoints reference them.
chat_pipe = pipeline("text-generation", model=default_model, token=HF_TOKEN)
disaster_pipe = chat_pipe
market_pipe = chat_pipe
# NOTE(review): no torch_dtype / device_map is passed, so precision and device
# placement are the pipeline defaults — confirm this fits the target hardware.

# Multimodal crop diagnostic model. Optional: if loading fails the app still
# starts, and run_crop_doctor reports the model as unavailable.
try:
    crop_pipe = pipeline("image-text-to-text", model=vision_model, token=HF_TOKEN)
except Exception as e:
    logger.warning("Crop model load failed: %s", e, exc_info=True)
    crop_pipe = None
# ==============================
# Helper Functions
# ==============================
def run_conversational(
    pipe,
    prompt: str,
    max_new_tokens: int = 200,
    *,
    return_full_text: bool = False,
) -> str:
    """Run a text-generation pipeline on *prompt* and return the generated text.

    Args:
        pipe: A Hugging Face ``text-generation`` pipeline, or any callable with
            the same call signature returning ``[{"generated_text": ...}]``.
        prompt: Text sent to the model as-is.
        max_new_tokens: Upper bound on the number of tokens generated.
        return_full_text: If False (default), return only the model's
            continuation. If True, the prompt is echoed in front of it, which
            is the pipeline's own default.

    Returns:
        The first result's ``generated_text``; ``str(output)`` if the pipeline
        returns an unexpected shape; or a ``"⚠️ Model error: ..."`` string if
        the pipeline call raises. Pipeline errors are never propagated.
    """
    try:
        output = pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            # Without this the pipeline echoes the user's prompt back inside
            # the "reply"/"summary"/"recommendation" returned to clients.
            return_full_text=return_full_text,
        )
    except Exception as e:
        logger.exception("Pipeline error: %s", e)
        return f"⚠️ Model error: {str(e)}"

    if isinstance(output, list) and output and isinstance(output[0], dict):
        return output[0].get("generated_text", str(output))
    return str(output)
def run_crop_doctor(image_bytes: bytes, symptoms: str) -> str:
    """Diagnose crop issues from an uploaded photo plus the farmer's description.

    Uses the module-level multimodal ``crop_pipe``. Never raises: every failure
    is returned as a user-facing ``"⚠️ ..."`` string.

    Args:
        image_bytes: Raw bytes of the uploaded image (any format PIL can open).
        symptoms: Free-text symptom description from the farmer.

    Returns:
        The first result's ``generated_text``; ``str(output)`` for an
        unexpected output shape; or a ``"⚠️ ..."`` message if the model is not
        loaded, the image cannot be decoded, or inference fails.
    """
    if not crop_pipe:
        return "⚠️ Crop analysis temporarily unavailable (model not loaded)."

    # Decode before inference so a bad upload gets an accurate message instead
    # of being reported as a model error. For non-image data PIL raises
    # UnidentifiedImageError; truncated or oversized files raise other errors.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        logger.warning("Crop Doctor could not decode uploaded image: %s", e)
        return "⚠️ Could not read the uploaded image. Please upload a valid image file."

    prompt = (
        f"The farmer reports: {symptoms}. "
        "Analyze the image and diagnose the likely crop disease. "
        "Then explain it simply and recommend possible treatment steps."
    )
    # NOTE(review): Llama 3.2 Vision processors typically expect an image
    # placeholder token (or chat-format messages) in the text when an image is
    # supplied; confirm this plain-prompt call works with the installed
    # transformers version.
    try:
        output = crop_pipe(image, prompt)
    except Exception as e:
        logger.exception("Crop Doctor pipeline error: %s", e)
        return f"⚠️ Unexpected model error: {str(e)}"

    if isinstance(output, list) and output and isinstance(output[0], dict):
        return output[0].get("generated_text", str(output))
    return str(output)
# ==============================
# API ROUTES
# ==============================
@app.post("/multilingual-chat")
async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
    # Authenticated: sends the user's query to chat_pipe and returns the
    # generated text under "reply".
    # NOTE(review): run_conversational is synchronous, so this async handler
    # blocks the event loop for the whole model call; consider a plain `def`
    # endpoint or a thread offload (with care for concurrent model access).
    check_auth(authorization)
    return {"reply": run_conversational(chat_pipe, req.query)}
@app.post("/disaster-summarizer")
async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
    # Authenticated: sends the submitted report to disaster_pipe and returns
    # the generated text under "summary".
    # NOTE(review): the report is passed verbatim with no summarization
    # instruction, so the output is whatever the model generates from raw
    # text; consider wrapping it in an explicit prompt.
    # NOTE(review): synchronous model call inside an async handler blocks the
    # event loop.
    check_auth(authorization)
    return {"summary": run_conversational(disaster_pipe, req.report)}
@app.post("/marketplace")
async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
    # Authenticated: sends the product text to market_pipe and returns the
    # generated text under "recommendation".
    # NOTE(review): the product text is passed verbatim with no instruction
    # describing the kind of recommendation wanted.
    # NOTE(review): synchronous model call inside an async handler blocks the
    # event loop.
    check_auth(authorization)
    return {"recommendation": run_conversational(market_pipe, req.product)}
@app.post("/vector-search")
async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
    # Authenticated: passes the query to query_vector (from smebuilder_vector)
    # and returns its result under "results".
    # Failures are reported in-band as {"error": ...}, keeping the existing
    # response contract; the full traceback is logged server-side.
    # NOTE(review): errors come back with HTTP 200 and include the exception
    # text; consider a 5xx status and a generic message.
    check_auth(authorization)
    try:
        results = query_vector(req.query)
    except Exception as e:
        logger.exception("Vector search error: %s", e)
        return {"error": f"Vector search error: {str(e)}"}
    return {"results": results}
@app.post("/crop-doctor")
async def crop_doctor(
    symptoms: str = Header(...),
    image: UploadFile = File(...),
    authorization: str | None = Header(None),
):
    # Authenticated: reads the uploaded image fully into memory, runs it with
    # the symptoms through run_crop_doctor, and returns the text under
    # "diagnosis".
    # NOTE(review): symptoms arrives as an HTTP header, which is awkward for
    # long or non-ASCII text; a Form field may suit better (API change).
    # NOTE(review): the upload is read with no size limit.
    check_auth(authorization)
    return {"diagnosis": run_crop_doctor(await image.read(), symptoms)}