import os
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from datasets import load_dataset
from sklearn.neighbors import NearestNeighbors

# =============== CONFIG ===============
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Embeddings backbone
OPENCLIP_BACKBONE = "ViT-H-14"
OPENCLIP_PRETRAIN = "laion2B-s32B-b79K"  # laion/CLIP-ViT-H-14-laion2B-s32B-b79K

# Dataset (this is your "model" source now)
DATASET_NAME = "tukey/human_face_emotions_roboflow"
DATASET_SPLIT = "train"

INDEX_SIZE = int(os.getenv("INDEX_SIZE", "400"))  # how many dataset examples to index
TOPK_NEAREST = 5        # shown in the nearest-neighbors gallery
KNN_K_FOR_CLASS = 25    # neighbors used to weight emotion votes

# Optional SD variations
USE_SD_VARIATIONS = True
SD_MODEL = "lambdalabs/sd-image-variations-diffusers"
# =====================================

# ---------- Load OpenCLIP for image embeddings ----------
try:
    import open_clip

    _openclip_model, _, _openclip_preprocess = open_clip.create_model_and_transforms(
        OPENCLIP_BACKBONE, pretrained=OPENCLIP_PRETRAIN
    )
    _openclip_model = _openclip_model.to(DEVICE).eval()
except Exception as e:
    raise RuntimeError(
        f"Failed to load OpenCLIP ({OPENCLIP_BACKBONE} / {OPENCLIP_PRETRAIN}). "
        f"Install 'open_clip_torch' and verify CUDA if available. Error: {e}"
    )


@torch.inference_mode()
def embed_image(img: Image.Image) -> np.ndarray:
    img = img.convert("RGB")
    tens = _openclip_preprocess(img).unsqueeze(0).to(DEVICE)
    feats = _openclip_model.encode_image(tens)
    feats = F.normalize(feats, dim=-1).squeeze(0).detach().cpu().numpy().astype(np.float32)
    return feats  # shape [D]


# ---------- Labels & stress mapping ----------
EMO_MAP = {
    "anger": "anger", "angry": "anger",
    "disgust": "disgust",
    "fear": "fear",
    "happy": "happy", "happiness": "happy",
    "neutral": "neutral", "calm": "neutral",
    "sad": "sad", "sadness": "sad",
    "surprise": "surprise",
    "contempt": "contempt",
}
ALLOWED = set(EMO_MAP.values())  # strict whitelist

STRESS_WEIGHTS = {
    "anger": 0.95,
    "fear": 0.90,
    "disgust": 0.70,
    "sad": 0.80,
    "surprise": 0.55,
    "neutral": 0.30,
    "contempt": 0.65,
    "happy": 0.10,
}


def _bucket(p: float) -> str:
    return "Low" if p < 33 else ("Medium" if p < 66 else "High")


# ---------- Load dataset & build index ----------
def _extract_label(rec: Dict[str, Any]) -> str:
    # Handle the label fields this dataset may expose
    if "label" in rec and rec["label"]:
        raw = rec["label"]
        if isinstance(raw, (list, tuple)):
            raw = raw[0]
        return str(raw).strip().lower()
    if "labels" in rec and rec["labels"]:
        raw = rec["labels"][0]
        return str(raw).strip().lower()
    if "qa" in rec and rec["qa"] and isinstance(rec["qa"], list):
        qa0 = rec["qa"][0]
        if qa0 and "answer" in qa0:
            return str(qa0["answer"]).strip().lower()
    return ""


def _map_allowed(lbl: str) -> str:
    # Map to a canonical name and drop anything outside the whitelist
    mapped = EMO_MAP.get(lbl, lbl)
    return mapped if mapped in ALLOWED else ""  # "" => drop


def _load_images_labels_for_index(n: int) -> Tuple[List[Image.Image], List[str]]:
    ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
    imgs, labels = [], []
    n = min(n, len(ds))
    for i in range(n):
        rec = ds[i]
        im = rec.get("image")
        if not isinstance(im, Image.Image):
            continue
        raw_lbl = _extract_label(rec)
        mapped = _map_allowed(raw_lbl)
        if not mapped:
            continue  # skip disallowed/empty labels
        imgs.append(im.copy())
        labels.append(mapped)
    return imgs, labels


def build_index(imgs: List[Image.Image]) -> Tuple[NearestNeighbors, np.ndarray]:
    vecs = [embed_image(im) for im in imgs]
    X = np.stack(vecs, axis=0)
    nn = NearestNeighbors(
        metric="cosine",
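# A minimal sketch of how the index built above gets queried (illustrative only; the
# nearest5/knn_probs helpers below wrap this same pattern, and "face.jpg" is a placeholder path):
#     q = embed_image(Image.open("face.jpg")).reshape(1, -1)
#     dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=5)
#     sims = 1.0 - dists[0]  # sklearn returns cosine distances; similarity = 1 - distance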
        n_neighbors=min(max(TOPK_NEAREST, KNN_K_FOR_CLASS), len(imgs)),
    )
    nn.fit(X)
    return nn, X


print("Loading dataset & building index (first time only)...")
DATASET_IMAGES, DATASET_LABELS = _load_images_labels_for_index(INDEX_SIZE)
if len(DATASET_IMAGES) == 0:
    raise RuntimeError("No images with allowed labels were loaded from the dataset.")
NN_MODEL, EMB_MATRIX = build_index(DATASET_IMAGES)
print(f"Index ready with {len(DATASET_IMAGES)} images (labels={sorted(set(DATASET_LABELS))}).")


# ---------- Nearest & KNN-based classification ----------
def nearest5(pil_img: Image.Image) -> List[Tuple[Image.Image, str]]:
    q = embed_image(pil_img).reshape(1, -1)
    n = min(TOPK_NEAREST, len(DATASET_IMAGES))
    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=n)
    out = []
    for rank, (dist, idx) in enumerate(zip(dists[0], idxs[0]), start=1):
        sim = max(0.0, 1.0 - float(dist))  # cosine distance -> similarity
        im = DATASET_IMAGES[int(idx)]
        caption = f"#{rank} sim={sim:.3f} idx={int(idx)}"
        out.append((im, caption))
    return out


def knn_probs(pil_img: Image.Image, k: int = KNN_K_FOR_CLASS) -> Dict[str, float]:
    q = embed_image(pil_img).reshape(1, -1)
    k = min(k, len(DATASET_IMAGES))
    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=k)
    sims = 1.0 - dists[0]  # higher is better
    sims = np.maximum(sims, 0.0)
    votes: Dict[str, float] = {}
    for sim, idx in zip(sims, idxs[0]):
        lbl = DATASET_LABELS[int(idx)]
        if lbl in ALLOWED:
            votes[lbl] = votes.get(lbl, 0.0) + float(sim)
    Z = sum(votes.values()) or 1.0
    return {lbl: v / Z for lbl, v in votes.items()}


def emotions_top3(pil_img: Image.Image) -> List[List[Any]]:
    probs = knn_probs(pil_img)
    items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
    table = []
    for i, (emo, p) in enumerate(items, start=1):
        table.append([i, emo, round(100.0 * p, 2)])
    # Pad the table if fewer than 3 emotions received votes
    seen = {r[1] for r in table}
    for fill in ["neutral", "other"]:
        if len(table) >= 3:
            break
        if fill in ALLOWED and fill not in seen:
            table.append([len(table) + 1, fill, 0.0])
    return table


def stress_index(pil_img: Image.Image) -> Tuple[str, float]:
    probs = knn_probs(pil_img)
    raw = sum(probs.get(k, 0.0) * STRESS_WEIGHTS.get(k, 0.5) for k in ALLOWED)
    pct = max(0.0, min(100.0, 100.0 * raw))
    return f"{pct:.1f}% ({_bucket(pct)})", pct
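# Worked example for stress_index (illustrative numbers, not taken from the dataset):
# if knn_probs returned {"happy": 0.6, "neutral": 0.3, "sad": 0.1}, then
#   raw = 0.6 * 0.10 + 0.3 * 0.30 + 0.1 * 0.80 = 0.23  ->  "23.0% (Low)"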
# ---------- Optional: SD image variations ----------
sd_pipe = None
if USE_SD_VARIATIONS:
    try:
        from diffusers import StableDiffusionImageVariationPipeline

        sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            SD_MODEL, torch_dtype=torch.float32
        )
        sd_pipe = sd_pipe.to(DEVICE)
    except Exception as e:
        print(f"[WARN] Could not load {SD_MODEL}. Generation disabled. Error: {e}")
        sd_pipe = None


def generate_one_variation(pil_img: Image.Image, steps: int) -> Image.Image:
    if sd_pipe is None:
        raise gr.Error("Image-variation pipeline is not available on this Space.")
    pil_img = pil_img.convert("RGB")
    out = sd_pipe(pil_img, guidance_scale=3.0, num_inference_steps=int(steps)).images[0]
    return out


# ===================== GRADIO UI =====================
CSS = ".box { border: 1px solid #e5e7eb; border-radius: 12px; padding: 10px; }"

with gr.Blocks(
    title="Face Emotion & Stress Analyzer — KNN over tukey dataset",
    css=CSS,
    fill_height=False,
) as demo:
    gr.Markdown(
        "### Face Emotion & Stress Analyzer — **KNN over `tukey/human_face_emotions_roboflow`**\n"
        "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
        "- Emotion model: **KNN using labels from `tukey/human_face_emotions_roboflow`**\n"
        "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers** (1 synthetic only)\n"
        "- Right column shows nearest 5 images from the dataset (clickable)."
    )

    # ---- Row 1: upload + (top3_emotion_original | stress_original) ----
    with gr.Row():
        with gr.Column(scale=2):
            upload_image = gr.Image(label="Upload face image", type="pil")
        with gr.Column(scale=1):
            top3_emotion_original = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False,
                label="Top-3 emotions (original image)",
                value=[],
            )
        with gr.Column(scale=1):
            stress_original = gr.Label(label="Stress index (original)")

    gr.Markdown("#### Analyze (no synthetics)")

    with gr.Row(equal_height=False):
        # ---------- LEFT COLUMN ----------
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("**gen_variations_control** — generate only **one** synthetic")
                steps = gr.Slider(8, 40, value=12, step=1, label="Diffusion steps (higher=slower/better)")
                gen_btn = gr.Button("Generate 1 synthetic", variant="primary")
            picked_synth = gr.Image(label="Synthetic preview")
            top3_emotion_synth = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False,
                label="top3_emotion_synth",
                value=[],
            )
            stress_synth = gr.Label(label="stress_synth")

        # ---------- RIGHT COLUMN ----------
        with gr.Column(scale=1):
            nearest_images_5 = gr.Gallery(
                label="nearest_images_5 (1-click on 5 examples)",
                columns=5,
                rows=1,
                height=200,
                allow_preview=False,
                show_label=True,
            )
            tops_emotion_nearest = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False,
                label="tops_emotion_nearest_image",
                value=[],
            )
            stress_nearest = gr.Label(label="stress_nearest_image")

    # --------- Hidden states ---------
    gallery_images_state = gr.State([])  # store PILs
    gallery_index_state = gr.State([])   # store dataset indexes (ints)

    # ================= Callbacks =================
    def on_upload(img: Image.Image):
        if img is None:
            return gr.update(), gr.update(value=""), [], [], []
        # original
        t3 = emotions_top3(img)
        s_label, _ = stress_index(img)
        # nearest gallery
        gal = nearest5(img)  # list[(PIL, caption)]
        gal_imgs = [g[0] for g in gal]
        gal_caps = [g[1] for g in gal]
        gallery = [(im, cap) for im, cap in zip(gal_imgs, gal_caps)]
        return t3, s_label, gallery, gal_imgs, list(range(len(gal_imgs)))

    upload_image.change(
        fn=on_upload,
        inputs=[upload_image],
        outputs=[top3_emotion_original, stress_original, nearest_images_5,
                 gallery_images_state, gallery_index_state],
    )

    def on_gallery_select(evt: gr.SelectData, imgs: List[Image.Image], idxs: List[int]):
        if imgs is None or not imgs:
            return [], ""
        i = int(evt.index) if evt is not None else 0
        i = max(0, min(i, len(imgs) - 1))
        im = imgs[i]
        t3 = emotions_top3(im)
        s_label, _ = stress_index(im)
        return t3, s_label

    nearest_images_5.select(
        fn=on_gallery_select,
        inputs=[gallery_images_state, gallery_index_state],
        outputs=[tops_emotion_nearest, stress_nearest],
    )

    def on_generate(img: Image.Image, steps_val: int):
        if img is None:
            raise gr.Error("Upload an image first.")
        if sd_pipe is None:
            raise gr.Error("Synthetic generation is disabled on this Space.")
        synth = generate_one_variation(img, steps_val)
        t3 = emotions_top3(synth)
        s_label, _ = stress_index(synth)
        return synth, t3, s_label

    gen_btn.click(
        fn=on_generate,
        inputs=[upload_image, steps],
        outputs=[picked_synth, top3_emotion_synth, stress_synth],
    )


if __name__ == "__main__":
    demo.launch()
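# Assumed runtime dependencies for this Space (a sketch, not a tested pin list; adjust
# versions to match your environment):
#   pip install torch open_clip_torch datasets scikit-learn gradio pillow numpy
#   pip install diffusers  # only needed when USE_SD_VARIATIONS is True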