import base64
import io
import os
import re
import tempfile

import cv2
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from together import Together

import releaf_ai

app = FastAPI()

# NOTE: the key is hardcoded here for convenience; see the sketch below for
# loading it from the environment instead of committing it to source control.
API_KEY = "1495bcdf0c72ed1e15d0e3e31e4301bd665cb28f2291bcc388164ed745a7aa24"
client = Together(api_key=API_KEY)
MODEL_NAME = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

SYSTEM_PROMPT = releaf_ai.SYSTEM_PROMPT

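# A minimal alternative sketch (assumption: the key is exported as
# TOGETHER_API_KEY before the server starts):
#
#   API_KEY = os.environ["TOGETHER_API_KEY"]
#   client = Together(api_key=API_KEY)
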
def encode_image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded JPEG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def extract_score(text: str):
    """Pull the integer after "Score:" from the model reply, or None."""
    match = re.search(r"(?i)Score:\s*(\d+)", text)
    return int(match.group(1)) if match else None

def extract_activity(text: str):
    """Pull the text after "Detected Activity:", even if it is the final line."""
    match = re.search(r"(?i)Detected Activity:\s*(.+?)(?:\n|$)", text)
    return match.group(1).strip() if match else "Unknown"

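# Hypothetical example of the reply format these helpers parse (the exact
# wording is defined by SYSTEM_PROMPT in releaf_ai, not here):
#
#   reply = "Detected Activity: composting food scraps\nScore: 8"
#   extract_activity(reply)  # -> "composting food scraps"
#   extract_score(reply)     # -> 8
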
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    try:
        content_type = file.content_type or ""

        if content_type.startswith("image"):
            # Single image: decode it directly from the upload.
            image = Image.open(io.BytesIO(await file.read())).convert("RGB")

        elif content_type.startswith("video"):
            # Video: write the upload to a temp file so OpenCV can read it.
            temp_path = tempfile.NamedTemporaryFile(delete=False).name
            with open(temp_path, "wb") as f:
                f.write(await file.read())

            cap = cv2.VideoCapture(temp_path)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            interval = max(total // 9, 1)

            # Sample 9 evenly spaced frames and downscale them.
            frames = []
            for i in range(9):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    img = Image.fromarray(frame).resize((256, 256))
                    frames.append(img)
            cap.release()
            os.remove(temp_path)

            if not frames:
                raise HTTPException(status_code=400, detail="Could not read frames from video")

            # Paste the sampled frames into a 3x3 contact sheet so the vision
            # model sees the whole clip in a single image.
            w, h = frames[0].size
            grid = Image.new("RGB", (3 * w, 3 * h))
            for idx, frame in enumerate(frames):
                grid.paste(frame, ((idx % 3) * w, (idx // 3) * h))
            image = grid

        else:
            raise HTTPException(status_code=400, detail="Unsupported file type")

        # Send the image (or frame grid) to the model as a base64 data URL.
        b64_img = encode_image_to_base64(image)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
            ]}
        ]
        res = client.chat.completions.create(model=MODEL_NAME, messages=messages)
        reply = res.choices[0].message.content

        return JSONResponse({
            "points": extract_score(reply),
            "task": extract_activity(reply),
            "raw": reply
        })

    except HTTPException:
        # Let deliberate 4xx errors pass through instead of being rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
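

# Minimal local-run sketch (assumption: uvicorn is installed; adjust host and
# port as needed). An example request is shown in the comment below.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the running server (hypothetical file name):
#   curl -F "file=@leaf.jpg;type=image/jpeg" http://localhost:8000/predict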