# Last change by rainmoon4546 (commit d077885):
# FIX: Corrected case-sensitive path for OVR_Checkpoints folder
import asyncio
import os
import threading
import traceback
from io import BytesIO

import torch
import cv2
import numpy as np
import joblib
import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from ultralytics import YOLO
from sklearn.preprocessing import MultiLabelBinarizer
# ======================================================
# 0. CONFIGURATION (PATHS MUST BE RELATIVE)
# ======================================================
# Every asset is resolved relative to BASE_PATH ("."); per the original note,
# the Hugging Face Space looks for the files in this directory.
BASE_PATH = "."


def _under_base(*parts):
    """Return the path formed by joining *parts* onto BASE_PATH."""
    return os.path.join(BASE_PATH, *parts)


YOLO_MODEL_PATH = _under_base("best.pt")
# NOTE(review): the previous comment here said to drop the long Colab
# date-stamped folder name, yet the path still contains it. Kept unchanged;
# confirm it matches the deployed folder layout.
MODEL_DIR = _under_base("OVR_Checkpoints-20251018T053026Z-1-001", "OVR_Checkpoints")
ENCODER_PATH = _under_base("dinov3_multilabel_encoder.pkl")
MAPPING_SAVE_PATH = _under_base("label_mapping_dict.joblib")

# Output attributes; the prediction dict is built by iterating them in this order.
label_columns = "product grade cap label brand type subtype volume".split()

# CPU for stability on HF Spaces (change only if a GPU is selected in the settings).
device = torch.device("cpu")

# --- DINOv3 auth token ---
# Read from the HUGGINGFACE_TOKEN environment variable (set as a Space secret);
# None when the variable is unset.
HF_AUTH_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# --- DINOv3 FEATURE EXTRACTION (defined before the models are loaded below) ---
def extract_dinov3_features(image_crop):
    """Embed *image_crop* with DINOv3 and return its token-0 features as a flat array.

    Relies on the module-level ``dinov3_processor``, ``dinov3_model`` and
    ``device``. The first token of ``last_hidden_state`` is used as the global
    (CLS) descriptor. It is moved to CPU, converted to NumPy and flattened.
    """
    pixel_batch = dinov3_processor(images=image_crop, return_tensors="pt")
    pixel_batch = pixel_batch.to(device)
    with torch.no_grad():
        hidden = dinov3_model(**pixel_batch).last_hidden_state
    cls_features = hidden[:, 0, :]
    return cls_features.cpu().numpy().flatten()
# === LOAD ALL MODELS INTO MEMORY ===
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary
# code; this is acceptable only because these are the app's own bundled files.
try:
    # 1. YOLO detector
    yolo_model = YOLO(YOLO_MODEL_PATH)

    # 2. DINOv3 processor and backbone (a token is required to download them)
    dinov3_model_name = "facebook/dinov3-convnext-small-pretrain-lvd1689m"
    dinov3_processor = AutoImageProcessor.from_pretrained(dinov3_model_name, token=HF_AUTH_TOKEN)
    dinov3_model = AutoModel.from_pretrained(dinov3_model_name, token=HF_AUTH_TOKEN).to(device).eval()

    # 3. scikit-learn assets: label binarizer, per-attribute label mapping,
    #    and one classifier file per binarizer class.
    mlb = joblib.load(ENCODER_PATH)
    mapping_dict = joblib.load(MAPPING_SAVE_PATH)
    num_classes = len(mlb.classes_)
    all_classifiers = []
    for i, class_name in enumerate(mlb.classes_):
        # Look up the classifier under a sanitized file name first
        # (' ', '/', ':' and '.' replaced by '_'), then under the raw class name.
        safe_class_name = str(class_name).replace(' ', '_').replace('/', '_').replace(':', '_').replace('.', '_')
        classifier_path = os.path.join(MODEL_DIR, f"clf_{i}_{safe_class_name}.pkl")
        if not os.path.exists(classifier_path):
            classifier_path = os.path.join(MODEL_DIR, f"clf_{i}_{class_name}.pkl")
        all_classifiers.append(joblib.load(classifier_path))
    print("✅ Semua model dan aset dimuat ke memori.")
except Exception as e:
    # Loading failed: report it (with traceback, so the cause is visible in the
    # Space logs) and stop the process; the API cannot serve without its models.
    print(f"❌ GAGAL MEMUAT MODEL SECARA LOKAL: {e}")
    traceback.print_exc()
    # SystemExit instead of the site-injected exit() builtin, which is meant for
    # the interactive shell and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(1)
# ======================================================
# 1. API INITIALIZATION & MIDDLEWARE
# ======================================================
app = FastAPI(title="HF Product Classifier API")

# CORS middleware (required for the Lovable.dev frontend): any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True may let any
# site make credentialed requests. The visible endpoint reads no cookies or auth
# headers, so allow_credentials=False (or an explicit origin list) is probably
# safer. Confirm whether the frontend sends credentials before changing it.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# ======================================================
# 2. MAIN PREDICTION ENDPOINT (FULL PIPELINE)
# ======================================================
# Inference runs in a worker thread so it no longer blocks the event loop.
# This lock keeps inference to one request at a time, the same effective
# concurrency as before, because the shared YOLO / DINOv3 / sklearn objects
# are not known to be safe to call from several threads at once.
_INFERENCE_LOCK = threading.Lock()


def _resolve_attributes(predicted_labels_set, class_proba_map):
    """Choose one value per attribute in ``label_columns`` from the multilabel output.

    Args:
        predicted_labels_set: labels whose classifier probability exceeded 0.5.
        class_proba_map: label -> positive-class probability for every class.

    Returns:
        dict mapping each column to one of:
        - the single predicted label of that column;
        - ``"<label> (xx.x%)"`` when none was predicted but the best label is > 20%;
        - ``"UNKNOWN"``;
        - ``"CONFLICT -> <label> (xx.x%)"`` when several labels were predicted.
    """
    final_output = {}
    for col in label_columns:
        column_labels = mapping_dict[col]
        intersection = predicted_labels_set.intersection(column_labels)

        # Highest-probability label among this column's candidates.
        best_proba = -1
        best_label = "UNKNOWN"
        for label in column_labels:
            proba = class_proba_map.get(label, 0.0)
            if proba > best_proba:
                best_proba = proba
                best_label = label

        if len(intersection) == 1:
            final_output[col] = intersection.pop()
        elif not intersection:
            # Nothing predicted: fall back to the best label only if it is above 20%.
            if best_proba > 0.20:
                final_output[col] = f"{best_label} ({best_proba*100:.1f}%)"
            else:
                final_output[col] = "UNKNOWN"
        else:
            # Conflict: several labels of this column predicted; report the most probable.
            final_output[col] = f"CONFLICT -> {best_label} ({best_proba*100:.1f}%)"
    return final_output


def _run_pipeline(img_pil):
    """Detect -> crop -> embed -> classify one RGB PIL image.

    Returns the JSON response body. Model calls happen under ``_INFERENCE_LOCK``.
    """
    with _INFERENCE_LOCK:
        # YOLO detection
        results = yolo_model(img_pil, verbose=False)
        if not results or not results[0].boxes:
            return {"status": "failed", "message": "No object detected."}

        # Keep only the highest-confidence box (convert to NumPy once, not twice).
        boxes = results[0].boxes.cpu().numpy()
        best_box = boxes[np.argmax(boxes.conf)]
        x_min, y_min, x_max, y_max = map(int, best_box.xyxy[0])
        confidence_score = float(best_box.conf[0])  # YOLO confidence
        image_crop = img_pil.crop((x_min, y_min, x_max, y_max))
        X_pred = extract_dinov3_features(image_crop).reshape(1, -1)

        # Classification: positive-class probability from each per-class classifier.
        Y_proba = np.array([clf.predict_proba(X_pred)[0, 1] for clf in all_classifiers])
        Y_pred_biner = (Y_proba > 0.5).astype(int).reshape(1, -1)
        class_proba_map = dict(zip(mlb.classes_, Y_proba))
        predicted_labels_set = set(mlb.inverse_transform(Y_pred_biner)[0])

    return {
        "status": "success",
        "confidence_score": f"{confidence_score*100:.2f}%",
        "prediction": _resolve_attributes(predicted_labels_set, class_proba_map),  # all 8 attributes
    }


@app.post("/predict")
async def predict_product(file: UploadFile = File(...)):
    """Accept an uploaded image, run the pipeline, and return a JSON response.

    Responses:
    - 200 with status "success" (prediction) or "failed" (no object detected);
    - 400 with status "error" when the upload cannot be decoded as an image;
    - 500 with status "error" and the exception message for any other failure.
    """
    try:
        contents = await file.read()
        try:
            img_pil = Image.open(BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Empty, truncated or non-image upload: a client error, not a server fault.
            return JSONResponse(status_code=400, content={"status": "error", "message": str(e)})

        # CPU-bound inference runs in the default thread pool so the event loop
        # keeps serving other requests (e.g. CORS preflights) meanwhile.
        loop = asyncio.get_running_loop()
        content = await loop.run_in_executor(None, _run_pipeline, img_pil)
        return JSONResponse(status_code=200, content=content)
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
# ======================================================
# 3. RUN SERVER (ENTRY POINT)
# ======================================================
# Per the original note, Hugging Face Spaces launches the server for the
# module-level `app` itself. NOTE(review): confirm against the Space's
# Dockerfile / README config.
# For local testing, run this file directly. The port defaults to 8000 and can
# be overridden with the PORT environment variable to match the host's expectation.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))