import os
import numpy as np
from typing import List, Tuple, Dict, Any
from PIL import Image
import torch
import torch.nn.functional as F
import gradio as gr
from datasets import load_dataset
from sklearn.neighbors import NearestNeighbors

# =============== CONFIG ===============
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Embeddings backbone
OPENCLIP_BACKBONE = "ViT-H-14"
OPENCLIP_PRETRAIN = "laion2B-s32B-b79K"  # laion/CLIP-ViT-H-14-laion2B-s32B-b79K

# Dataset (THIS IS YOUR "MODEL" SOURCE NOW)
DATASET_NAME = "tukey/human_face_emotions_roboflow"
DATASET_SPLIT = "train"
INDEX_SIZE = int(os.getenv("INDEX_SIZE", 400))  # how many dataset examples to index
TOPK_NEAREST = 5        # shown in the gallery
KNN_K_FOR_CLASS = 25    # used to weight emotions

# Optional SD variations
USE_SD_VARIATIONS = True
SD_MODEL = "lambdalabs/sd-image-variations-diffusers"
# =====================================
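
# INDEX_SIZE can be overridden per deployment without code changes, e.g. (illustrative):
#   INDEX_SIZE=800 python app.py
# Larger values give the KNN more neighbors to vote with, at the cost of a slower
# startup while the index is embedded and built.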

# ---------- Load OpenCLIP for image embeddings ----------
try:
    import open_clip

    _openclip_model, _, _openclip_preprocess = open_clip.create_model_and_transforms(
        OPENCLIP_BACKBONE, pretrained=OPENCLIP_PRETRAIN
    )
    _openclip_model = _openclip_model.to(DEVICE).eval()
except Exception as e:
    raise RuntimeError(
        f"Failed to load OpenCLIP ({OPENCLIP_BACKBONE} / {OPENCLIP_PRETRAIN}). "
        f"Install 'open_clip_torch' and verify CUDA if available. Error: {e}"
    )

def embed_image(img: Image.Image) -> np.ndarray:
    img = img.convert("RGB")
    tens = _openclip_preprocess(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():  # inference only; avoid building an autograd graph
        feats = _openclip_model.encode_image(tens)
    feats = F.normalize(feats, dim=-1).squeeze(0).detach().cpu().numpy().astype(np.float32)
    return feats  # shape [D]
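
# With the ViT-H-14 backbone the returned vector should be 1024-dimensional; because it
# is L2-normalized here, cosine similarity between two embeddings reduces to a dot product.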

# ---------- Labels & stress mapping ----------
EMO_MAP = {
    "anger": "anger", "angry": "anger",
    "disgust": "disgust",
    "fear": "fear",
    "happy": "happy", "happiness": "happy",
    "neutral": "neutral", "calm": "neutral",
    "sad": "sad", "sadness": "sad",
    "surprise": "surprise",
    "contempt": "contempt",
}
ALLOWED = set(EMO_MAP.values())  # hard whitelist

STRESS_WEIGHTS = {
    "anger": 0.95, "fear": 0.90, "disgust": 0.70, "sad": 0.80,
    "surprise": 0.55, "neutral": 0.30, "contempt": 0.65, "happy": 0.10,
}

def _bucket(p: float) -> str:
    return "Low" if p < 33 else ("Medium" if p < 66 else "High")

# ---------- Load dataset & build index ----------
def _extract_label(rec: Dict[str, Any]) -> str:
    # Handle the label fields the dataset may expose
    if "label" in rec and rec["label"]:
        raw = rec["label"]
        if isinstance(raw, (list, tuple)):
            raw = raw[0]
        return str(raw).strip().lower()
    if "labels" in rec and rec["labels"]:
        raw = rec["labels"][0]
        return str(raw).strip().lower()
    if "qa" in rec and rec["qa"] and isinstance(rec["qa"], list):
        qa0 = rec["qa"][0]
        if qa0 and "answer" in qa0:
            return str(qa0["answer"]).strip().lower()
    return ""

def _map_allowed(lbl: str) -> str:
    # Map to a canonical emotion name and filter out unknown labels
    mapped = EMO_MAP.get(lbl, lbl)
    return mapped if mapped in ALLOWED else ""  # "" => drop

def _load_images_labels_for_index(n: int) -> Tuple[List[Image.Image], List[str]]:
    ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
    imgs, labels = [], []
    n = min(n, len(ds))
    for i in range(n):
        rec = ds[i]
        im = rec.get("image")
        if not isinstance(im, Image.Image):
            continue
        raw_lbl = _extract_label(rec)
        mapped = _map_allowed(raw_lbl)
        if not mapped:
            continue  # drop disallowed/empty labels
        imgs.append(im.copy())
        labels.append(mapped)
    return imgs, labels

def build_index(imgs: List[Image.Image]) -> Tuple[NearestNeighbors, np.ndarray]:
    vecs = [embed_image(im) for im in imgs]
    X = np.stack(vecs, axis=0)
    nn = NearestNeighbors(metric="cosine", n_neighbors=min(max(TOPK_NEAREST, KNN_K_FOR_CLASS), len(imgs)))
    nn.fit(X)
    return nn, X
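
# scikit-learn's "cosine" metric returns cosine *distance* (1 - cosine similarity), so with
# the L2-normalized embeddings above, distances fall in [0, 2] and kneighbors() yields
# neighbors sorted from most to least similar.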
| print("Loading dataset & building index (first time only)...") | |
| DATASET_IMAGES, DATASET_LABELS = _load_images_labels_for_index(INDEX_SIZE) | |
| if len(DATASET_IMAGES) == 0: | |
| raise RuntimeError("No images with allowed labels were loaded from the dataset.") | |
| NN_MODEL, EMB_MATRIX = build_index(DATASET_IMAGES) | |
| print(f"Index ready with {len(DATASET_IMAGES)} images (labels={sorted(set(DATASET_LABELS))}).") | |

# ---------- Nearest & KNN-based classification ----------
def nearest5(pil_img: Image.Image) -> List[Tuple[Image.Image, str]]:
    q = embed_image(pil_img).reshape(1, -1)
    n = min(5, len(DATASET_IMAGES))
    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=n)
    out = []
    for rank, (dist, idx) in enumerate(zip(dists[0], idxs[0]), start=1):
        sim = max(0.0, 1.0 - float(dist))  # cosine distance -> similarity
        im = DATASET_IMAGES[int(idx)]
        caption = f"#{rank} sim={sim:.3f} idx={int(idx)}"
        out.append((im, caption))
    return out
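
# Illustrative return value (captions follow the format above; similarity figures are hypothetical):
#   [(<PIL.Image>, "#1 sim=0.812 idx=37"), (<PIL.Image>, "#2 sim=0.794 idx=112"), ...]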

def knn_probs(pil_img: Image.Image, k: int = KNN_K_FOR_CLASS) -> Dict[str, float]:
    q = embed_image(pil_img).reshape(1, -1)
    k = min(k, len(DATASET_IMAGES))
    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=k)
    sims = 1.0 - dists[0]  # higher is better
    sims = np.maximum(sims, 0.0)
    votes: Dict[str, float] = {}
    for sim, idx in zip(sims, idxs[0]):
        lbl = DATASET_LABELS[int(idx)]
        if lbl in ALLOWED:
            votes[lbl] = votes.get(lbl, 0.0) + float(sim)
    Z = sum(votes.values()) or 1.0
    return {lbl: v / Z for lbl, v in votes.items()}
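
# Illustrative output (hypothetical values): {"happy": 0.62, "neutral": 0.25, "surprise": 0.13}
# The similarity-weighted votes are normalized by Z, so the values sum to 1 over the labels present.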

def emotions_top3(pil_img: Image.Image) -> List[List[Any]]:
    probs = knn_probs(pil_img)
    items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
    table = []
    for i, (emo, p) in enumerate(items, start=1):
        table.append([i, emo, round(100.0 * p, 2)])
    # Pad the table if fewer than 3 emotions were found
    seen = {r[1] for r in table}
    for fill in ["neutral", "other"]:
        if len(table) >= 3:
            break
        if fill in ALLOWED and fill not in seen:
            table.append([len(table) + 1, fill, 0.0])
    return table

def stress_index(pil_img: Image.Image) -> Tuple[str, float]:
    probs = knn_probs(pil_img)
    raw = sum(probs.get(k, 0.0) * STRESS_WEIGHTS.get(k, 0.5) for k in ALLOWED)
    pct = max(0.0, min(100.0, 100.0 * raw))
    return f"{pct:.1f}% ({_bucket(pct)})", pct

# ---------- Optional: SD image variations ----------
sd_pipe = None
if USE_SD_VARIATIONS:
    try:
        from diffusers import StableDiffusionImageVariationPipeline

        sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            SD_MODEL, torch_dtype=torch.float32
        )
        sd_pipe = sd_pipe.to(DEVICE)
    except Exception as e:
        print(f"[WARN] Could not load {SD_MODEL}. Generation disabled. Error: {e}")
        sd_pipe = None

def generate_one_variation(pil_img: Image.Image, steps: int) -> Image.Image:
    if sd_pipe is None:
        raise gr.Error("Image-variation pipeline is not available on this Space.")
    pil_img = pil_img.convert("RGB")
    out = sd_pipe(pil_img, guidance_scale=3.0, num_inference_steps=int(steps)).images[0]
    return out
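
# The steps value comes from the UI slider below (8-40); more steps generally means a slower
# but cleaner variation, while guidance_scale stays fixed at a mild 3.0 here.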

# ===================== GRADIO UI =====================
CSS = ".box { border: 1px solid #e5e7eb; border-radius: 12px; padding: 10px; }"

with gr.Blocks(title="Face Emotion & Stress Analyzer — KNN over tukey dataset", css=CSS, fill_height=False) as demo:
    gr.Markdown(
        "### Face Emotion & Stress Analyzer — **KNN over `tukey/human_face_emotions_roboflow`**\n"
        "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
        "- Emotion model: **KNN using labels from `tukey/human_face_emotions_roboflow`**\n"
        "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers** (1 synthetic only)\n"
        "- Right column shows nearest 5 images from the dataset (clickable)."
    )

    # ---- Row 1: upload + (top3_emotion_original | stress_original) ----
    with gr.Row():
        with gr.Column(scale=2):
            upload_image = gr.Image(label="Upload face image", type="pil")
        with gr.Column(scale=1):
            top3_emotion_original = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False, label="Top-3 emotions (original image)",
                value=[]
            )
        with gr.Column(scale=1):
            stress_original = gr.Label(label="Stress index (original)")

    gr.Markdown("#### Analyze (no synthetics)")

    with gr.Row(equal_height=False):
        # ---------- LEFT COLUMN ----------
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("**gen_variations_control** — generate only **one** synthetic")
                steps = gr.Slider(8, 40, value=12, step=1, label="Diffusion steps (higher=slower/better)")
                gen_btn = gr.Button("Generate 1 synthetic", variant="primary")
            picked_synth = gr.Image(label="Synthetic preview")
            top3_emotion_synth = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False, label="top3_emotion_synth",
                value=[]
            )
            stress_synth = gr.Label(label="stress_synth")

        # ---------- RIGHT COLUMN ----------
        with gr.Column(scale=1):
            nearest_images_5 = gr.Gallery(
                label="nearest_images_5 (1-click on 5 examples)",
                columns=5, rows=1, height=200, allow_preview=False, show_label=True
            )
            tops_emotion_nearest = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False, label="tops_emotion_nearest_image",
                value=[]
            )
            stress_nearest = gr.Label(label="stress_nearest_image")

    # --------- Hidden states ---------
    gallery_images_state = gr.State([])  # store PILs
    gallery_index_state = gr.State([])   # store dataset indexes (ints)

    # ================= Callbacks =================
    def on_upload(img: Image.Image):
        if img is None:
            return gr.update(), gr.update(value=""), [], [], []
        # original
        t3 = emotions_top3(img)
        s_label, _ = stress_index(img)
        # nearest gallery
        gal = nearest5(img)  # list[(PIL, caption)]
        gal_imgs = [g[0] for g in gal]
        gal_caps = [g[1] for g in gal]
        gallery = [(im, cap) for im, cap in zip(gal_imgs, gal_caps)]
        return t3, s_label, gallery, gal_imgs, list(range(len(gal_imgs)))

    upload_image.change(
        fn=on_upload,
        inputs=[upload_image],
        outputs=[top3_emotion_original, stress_original, nearest_images_5, gallery_images_state, gallery_index_state]
    )

    def on_gallery_select(evt: gr.SelectData, imgs: List[Image.Image], idxs: List[int]):
        if imgs is None or not imgs:
            return [], ""
        i = int(evt.index) if evt is not None else 0
        i = max(0, min(i, len(imgs) - 1))
        im = imgs[i]
        t3 = emotions_top3(im)
        s_label, _ = stress_index(im)
        return t3, s_label

    nearest_images_5.select(
        fn=on_gallery_select,
        inputs=[gallery_images_state, gallery_index_state],
        outputs=[tops_emotion_nearest, stress_nearest]
    )

    def on_generate(img: Image.Image, steps_val: int):
        if img is None:
            raise gr.Error("Upload an image first.")
        if sd_pipe is None:
            raise gr.Error("Synthetic generation is disabled on this Space.")
        synth = generate_one_variation(img, steps_val)
        t3 = emotions_top3(synth)
        s_label, _ = stress_index(synth)
        return synth, t3, s_label

    gen_btn.click(
        fn=on_generate,
        inputs=[upload_image, steps],
        outputs=[picked_synth, top3_emotion_synth, stress_synth]
    )

if __name__ == "__main__":
    demo.launch()