import os
import numpy as np
from typing import List, Tuple, Dict, Any
from PIL import Image
import torch
import torch.nn.functional as F
import gradio as gr
from datasets import load_dataset
from sklearn.neighbors import NearestNeighbors
# =============== CONFIG ===============
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Embeddings backbone
OPENCLIP_BACKBONE = "ViT-H-14"
OPENCLIP_PRETRAIN = "laion2B-s32B-b79K" # laion/CLIP-ViT-H-14-laion2B-s32B-b79K
# Dataset (the KNN "model" is built directly from these labeled examples)
DATASET_NAME = "tukey/human_face_emotions_roboflow"
DATASET_SPLIT = "train"
INDEX_SIZE = int(os.getenv("INDEX_SIZE", 400)) # how many dataset examples to index
TOPK_NEAREST = 5 # how many neighbors to show in the gallery
KNN_K_FOR_CLASS = 25 # how many neighbors vote when weighting emotions
# Optional SD variations
USE_SD_VARIATIONS = True
SD_MODEL = "lambdalabs/sd-image-variations-diffusers"
# =====================================
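# How it works: every image (the query, the indexed dataset images, and any SD
# variation) is embedded with OpenCLIP; the dataset embeddings are indexed with
# scikit-learn NearestNeighbors under a cosine metric; emotions come from
# similarity-weighted KNN voting over the dataset labels, and the stress index
# is a weighted sum of those emotion probabilities.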
# ---------- Load OpenCLIP for image embeddings ----------
try:
import open_clip
_openclip_model, _, _openclip_preprocess = open_clip.create_model_and_transforms(
OPENCLIP_BACKBONE, pretrained=OPENCLIP_PRETRAIN
)
_openclip_model = _openclip_model.to(DEVICE).eval()
except Exception as e:
raise RuntimeError(
f"Failed to load OpenCLIP ({OPENCLIP_BACKBONE} / {OPENCLIP_PRETRAIN}). "
f"Install 'open_clip_torch' and verify CUDA if available. Error: {e}"
)
@torch.inference_mode()
def embed_image(img: Image.Image) -> np.ndarray:
img = img.convert("RGB")
tens = _openclip_preprocess(img).unsqueeze(0).to(DEVICE)
feats = _openclip_model.encode_image(tens)
feats = F.normalize(feats, dim=-1).squeeze(0).detach().cpu().numpy().astype(np.float32)
return feats # shape [D]
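# Usage sketch (assuming a local "face.jpg" exists):
#   vec = embed_image(Image.open("face.jpg"))
# The vector is unit-normalized (1024-dimensional for ViT-H-14), so cosine
# similarity between two embeddings is simply their dot product.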
# ---------- Labels & stress mapping ----------
EMO_MAP = {
"anger": "anger", "angry": "anger",
"disgust": "disgust",
"fear": "fear",
"happy": "happy", "happiness": "happy",
"neutral": "neutral", "calm": "neutral",
"sad": "sad", "sadness": "sad",
"surprise": "surprise",
"contempt": "contempt",
}
ALLOWED = set(EMO_MAP.values()) # hard whitelist of emotion labels
STRESS_WEIGHTS = {
"anger": 0.95, "fear": 0.90, "disgust": 0.70, "sad": 0.80,
"surprise": 0.55, "neutral": 0.30, "contempt": 0.65, "happy": 0.10,
}
def _bucket(p: float) -> str:
return "Low" if p < 33 else ("Medium" if p < 66 else "High")
# ---------- Load dataset & build index ----------
def _extract_label(rec: Dict[str, Any]) -> str:
    # try the possible label fields used by this dataset
if "label" in rec and rec["label"]:
raw = rec["label"]
if isinstance(raw, (list, tuple)): raw = raw[0]
return str(raw).strip().lower()
if "labels" in rec and rec["labels"]:
raw = rec["labels"][0]
return str(raw).strip().lower()
if "qa" in rec and rec["qa"] and isinstance(rec["qa"], list):
qa0 = rec["qa"][0]
if qa0 and "answer" in qa0:
return str(qa0["answer"]).strip().lower()
return ""
def _map_allowed(lbl: str) -> str:
    # map to a canonical emotion name; unknown labels are dropped
mapped = EMO_MAP.get(lbl, lbl)
return mapped if mapped in ALLOWED else "" # "" => drop
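# _extract_label checks the label fields this dataset may use ("label", a
# "labels" list, or a "qa" answer), and _map_allowed normalizes synonyms such
# as "angry" -> "anger" while dropping anything outside the whitelist.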
def _load_images_labels_for_index(n: int) -> Tuple[List[Image.Image], List[str]]:
ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
imgs, labels = [], []
n = min(n, len(ds))
for i in range(n):
rec = ds[i]
im = rec.get("image")
if not isinstance(im, Image.Image):
continue
raw_lbl = _extract_label(rec)
mapped = _map_allowed(raw_lbl)
if not mapped:
            continue # drop empty/disallowed labels
imgs.append(im.copy())
labels.append(mapped)
return imgs, labels
def build_index(imgs: List[Image.Image]) -> Tuple[NearestNeighbors, np.ndarray]:
vecs = [embed_image(im) for im in imgs]
X = np.stack(vecs, axis=0)
nn = NearestNeighbors(metric="cosine", n_neighbors=min(max(TOPK_NEAREST, KNN_K_FOR_CLASS), len(imgs)))
nn.fit(X)
return nn, X
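# Note: with metric="cosine", scikit-learn falls back to a brute-force search
# (its tree indexes do not support cosine), which is fine for a few hundred
# vectors. kneighbors() returns cosine *distances* (1 - similarity), so the
# callers below convert back to similarities.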
print("Loading dataset & building index (first time only)...")
DATASET_IMAGES, DATASET_LABELS = _load_images_labels_for_index(INDEX_SIZE)
if len(DATASET_IMAGES) == 0:
raise RuntimeError("No images with allowed labels were loaded from the dataset.")
NN_MODEL, EMB_MATRIX = build_index(DATASET_IMAGES)
print(f"Index ready with {len(DATASET_IMAGES)} images (labels={sorted(set(DATASET_LABELS))}).")
# ---------- Nearest & KNN-based classification ----------
def nearest5(pil_img: Image.Image) -> List[Tuple[Image.Image, str]]:
q = embed_image(pil_img).reshape(1, -1)
    n = min(TOPK_NEAREST, len(DATASET_IMAGES))
dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=n)
out = []
for rank, (dist, idx) in enumerate(zip(dists[0], idxs[0]), start=1):
sim = max(0.0, 1.0 - float(dist)) # cosine distance -> similarity
im = DATASET_IMAGES[int(idx)]
caption = f"#{rank} sim={sim:.3f} idx={int(idx)}"
out.append((im, caption))
return out
def knn_probs(pil_img: Image.Image, k: int = KNN_K_FOR_CLASS) -> Dict[str, float]:
q = embed_image(pil_img).reshape(1, -1)
k = min(k, len(DATASET_IMAGES))
dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=k)
sims = 1.0 - dists[0] # higher is better
sims = np.maximum(sims, 0.0)
votes: Dict[str, float] = {}
for sim, idx in zip(sims, idxs[0]):
lbl = DATASET_LABELS[int(idx)]
if lbl in ALLOWED:
votes[lbl] = votes.get(lbl, 0.0) + float(sim)
Z = sum(votes.values()) or 1.0
    return {lbl: v / Z for lbl, v in votes.items()}
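# Example of the returned mapping (hypothetical values):
#   {"happy": 0.62, "neutral": 0.25, "surprise": 0.13}
# i.e. similarity-normalized vote shares over the labels of the k nearest
# dataset images.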
def emotions_top3(pil_img: Image.Image) -> List[List[Any]]:
probs = knn_probs(pil_img)
items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
table = []
for i, (emo, p) in enumerate(items, start=1):
table.append([i, emo, round(100.0 * p, 2)])
    # pad to 3 rows if fewer emotions were found
seen = {r[1] for r in table}
for fill in ["neutral", "other"]:
if len(table) >= 3: break
if fill in ALLOWED and fill not in seen:
table.append([len(table)+1, fill, 0.0])
return table
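# Example table passed to the Dataframe component (hypothetical values):
#   [[1, "happy", 62.0], [2, "neutral", 25.0], [3, "surprise", 13.0]]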
def stress_index(pil_img: Image.Image) -> Tuple[str, float]:
probs = knn_probs(pil_img)
raw = sum(probs.get(k, 0.0) * STRESS_WEIGHTS.get(k, 0.5) for k in ALLOWED)
pct = max(0.0, min(100.0, 100.0 * raw))
return f"{pct:.1f}% ({_bucket(pct)})", pct
# ---------- Optional: SD image variations ----------
sd_pipe = None
if USE_SD_VARIATIONS:
try:
from diffusers import StableDiffusionImageVariationPipeline
sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
SD_MODEL, torch_dtype=torch.float32
)
sd_pipe = sd_pipe.to(DEVICE)
except Exception as e:
print(f"[WARN] Could not load {SD_MODEL}. Generation disabled. Error: {e}")
sd_pipe = None
def generate_one_variation(pil_img: Image.Image, steps: int) -> Image.Image:
if sd_pipe is None:
raise gr.Error("Image-variation pipeline is not available on this Space.")
pil_img = pil_img.convert("RGB")
out = sd_pipe(pil_img, guidance_scale=3.0, num_inference_steps=int(steps)).images[0]
return out
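# Usage sketch (assuming sd_pipe loaded successfully; "face.jpg" is a placeholder path):
#   synth = generate_one_variation(Image.open("face.jpg"), steps=12)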
# ===================== GRADIO UI =====================
CSS = ".box { border: 1px solid #e5e7eb; border-radius: 12px; padding: 10px; }"
with gr.Blocks(title="Face Emotion & Stress Analyzer — KNN over tukey dataset", css=CSS, fill_height=False) as demo:
gr.Markdown(
"### Face Emotion & Stress Analyzer — **KNN over `tukey/human_face_emotions_roboflow`**\n"
"- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
"- Emotion model: **KNN using labels from `tukey/human_face_emotions_roboflow`**\n"
"- Optional SD variations: **lambdalabs/sd-image-variations-diffusers** (1 synthetic only)\n"
"- Right column shows nearest 5 images from the dataset (clickable)."
)
# ---- Row 1: upload + (top3_emotion_original | stress_original) ----
with gr.Row():
with gr.Column(scale=2):
upload_image = gr.Image(label="Upload face image", type="pil")
with gr.Column(scale=1):
top3_emotion_original = gr.Dataframe(
headers=["Rank", "Emotion", "Confidence (%)"],
datatype=["number", "str", "number"],
interactive=False, label="Top-3 emotions (original image)",
value=[]
)
with gr.Column(scale=1):
stress_original = gr.Label(label="Stress index (original)")
gr.Markdown("#### Analyze (no synthetics)")
with gr.Row(equal_height=False):
# ---------- LEFT COLUMN ----------
with gr.Column(scale=1):
with gr.Group():
gr.Markdown("**gen_variations_control** — generate only **one** synthetic")
steps = gr.Slider(8, 40, value=12, step=1, label="Diffusion steps (higher=slower/better)")
gen_btn = gr.Button("Generate 1 synthetic", variant="primary")
picked_synth = gr.Image(label="Synthetic preview")
top3_emotion_synth = gr.Dataframe(
headers=["Rank", "Emotion", "Confidence (%)"],
datatype=["number", "str", "number"],
interactive=False, label="top3_emotion_synth",
value=[]
)
stress_synth = gr.Label(label="stress_synth")
# ---------- RIGHT COLUMN ----------
with gr.Column(scale=1):
nearest_images_5 = gr.Gallery(
label="nearest_images_5 (1-click on 5 examples)",
columns=5, rows=1, height=200, allow_preview=False, show_label=True
)
tops_emotion_nearest = gr.Dataframe(
headers=["Rank", "Emotion", "Confidence (%)"],
datatype=["number", "str", "number"],
interactive=False, label="tops_emotion_nearest_image",
value=[]
)
stress_nearest = gr.Label(label="stress_nearest_image")
# --------- Hidden states ---------
gallery_images_state = gr.State([]) # store PILs
gallery_index_state = gr.State([]) # store dataset indexes (ints)
# ================= Callbacks =================
def on_upload(img: Image.Image):
if img is None:
return gr.update(), gr.update(value=""), [], [], []
# original
t3 = emotions_top3(img)
s_label, _ = stress_index(img)
# nearest gallery
gal = nearest5(img) # list[(PIL, caption)]
gal_imgs = [g[0] for g in gal]
gal_caps = [g[1] for g in gal]
gallery = [(im, cap) for im, cap in zip(gal_imgs, gal_caps)]
return t3, s_label, gallery, gal_imgs, list(range(len(gal_imgs)))
upload_image.change(
fn=on_upload,
inputs=[upload_image],
outputs=[top3_emotion_original, stress_original, nearest_images_5, gallery_images_state, gallery_index_state]
)
def on_gallery_select(evt: gr.SelectData, imgs: List[Image.Image], idxs: List[int]):
if imgs is None or not imgs:
return [], ""
i = int(evt.index) if evt is not None else 0
i = max(0, min(i, len(imgs)-1))
im = imgs[i]
t3 = emotions_top3(im)
s_label, _ = stress_index(im)
return t3, s_label
nearest_images_5.select(
fn=on_gallery_select,
inputs=[gallery_images_state, gallery_index_state],
outputs=[tops_emotion_nearest, stress_nearest]
)
def on_generate(img: Image.Image, steps_val: int):
if img is None:
raise gr.Error("Upload an image first.")
if sd_pipe is None:
raise gr.Error("Synthetic generation is disabled on this Space.")
synth = generate_one_variation(img, steps_val)
t3 = emotions_top3(synth)
s_label, _ = stress_index(synth)
return synth, t3, s_label
gen_btn.click(
fn=on_generate,
inputs=[upload_image, steps],
outputs=[picked_synth, top3_emotion_synth, stress_synth]
)
if __name__ == "__main__":
demo.launch()