Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import requests | |
| import random | |
| from io import BytesIO | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
| from lime import lime_image | |
| from skimage.segmentation import slic, mark_boundaries | |
| # Add logging | |
| import logging, os | |
| from logging.handlers import RotatingFileHandler | |
# Logging goes to two places: the console and a rotating file named
# "interp.log" inside a "logs" directory next to this script.
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")

logger = logging.getLogger("vit_lime_uncertainty")

# Handlers are attached only when the logger has none yet, so executing this
# module body a second time does not produce duplicate log lines.
if len(logger.handlers) == 0:
    logger.setLevel(logging.INFO)

    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")

    # Console handler.
    sh = logging.StreamHandler()
    sh.setFormatter(fmt)

    # File handler: rolls over at 5,000,000 bytes and keeps 3 backups.
    fh = RotatingFileHandler(
        logfile,
        maxBytes=5_000_000,
        backupCount=3,
        encoding="utf-8",
    )
    fh.setFormatter(fmt)

    logger.addHandler(sh)
    logger.addHandler(fh)
| # ---- Step 1: Load model & processor ---- | |
# Checkpoint identifier shared by the image processor and the classifier so
# that preprocessing always matches the weights being loaded.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Put the network in evaluation mode before it is used for prediction.
model.eval()
| # ---- Step 2: Robust random image downloader (multiple providers + fallback) ---- | |
def download_random_image(size=(224, 224)):
    """Fetch a random photo from one of several public image providers.

    A search term is chosen at random and each provider URL is tried in
    order. The first response that has status 200 and decodes as an image is
    converted to RGB, resized to ``size`` and returned. Providers that fail
    are logged with a warning and skipped.

    Args:
        size: Target ``(width, height)`` in pixels. It is used both to build
            the provider URLs and for the final resize.

    Returns:
        A PIL image in RGB mode resized to ``size``. When every provider
        fails, a solid grey ``(128, 128, 128)`` image of that size is
        returned instead of raising.
    """
    search_terms = ["dog", "cat", "bird", "car", "airplane", "horse", "elephant", "tiger", "lion", "bear"]
    term = random.choice(search_terms)
    width, height = size
    providers = [
        f"https://source.unsplash.com/{width}x{height}/?{term}",
        f"https://picsum.photos/seed/{term}/{width}/{height}",
        f"https://loremflickr.com/{width}/{height}/{term}",
        f"https://placekitten.com/{width}/{height}",
    ]
    headers = {"User-Agent": "Mozilla/5.0 (compatible; ImageFetcher/1.0)"}

    # Newer Pillow exposes the filter as Image.Resampling.LANCZOS, older
    # releases only as Image.LANCZOS. Resolving it once here replaces the
    # previous broad `except Exception` around resize(), which would also
    # have hidden unrelated resize failures.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS

    for url in providers:
        # Only the network call can raise RequestException, so keep the try
        # limited to it.
        try:
            r = requests.get(url, timeout=10, headers=headers, allow_redirects=True)
        except requests.RequestException as e:
            logger.warning("Request exception %s for %s", e, url)
            continue

        if r.status_code != 200:
            logger.warning("Provider %s returned status %d", url, r.status_code)
            continue

        try:
            img = Image.open(BytesIO(r.content)).convert("RGB")
        except Exception as e:
            # Deliberately broad: any payload that cannot be decoded (HTML
            # error page, truncated file, ...) just means "try the next
            # provider".
            logger.warning("Failed to open image from %s: %s", url, e)
            continue

        img = img.resize(size, lanczos)
        logger.info("Downloaded image for '%s' from %s", term, url)
        return img

    logger.error("All providers failed; using fallback solid-color image.")
    return Image.new("RGB", size, color=(128, 128, 128))
| # ---- Step 3: Classifier function for LIME ---- | |
def classifier_fn(images_batch):
    """Return class probabilities for a batch of images (callback for LIME).

    Args:
        images_batch: Either a numpy array of images, described by the
            original author as shape (N, H, W, 3) with values in [0, 255],
            or a list of images. A numpy array is split into per-image
            arrays cast to uint8; any other input is passed to the processor
            unchanged.

    Returns:
        A numpy array of shape (N, num_classes) holding the softmax of the
        model logits for each image.
    """
    if not isinstance(images_batch, np.ndarray):
        imgs = images_batch
    else:
        # The processor takes a list of per-image arrays.
        imgs = [frame.astype(np.uint8) for frame in images_batch]

    inputs = processor(images=imgs, return_tensors="pt")

    # Inference only: no gradients are tracked.
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs
| # ---- Step 4: Run LIME multiple times to estimate uncertainty ---- | |
def lime_explanations_with_uncertainty(img_pil, n_runs=6, num_samples=1000, segments_kwargs=None):
    """Run LIME several times on one image and summarise the attributions.

    Each run produces a per-pixel attribution map for the model's predicted
    class by painting every superpixel with its LIME weight. The maps are
    stacked and reduced to a per-pixel mean and standard deviation; the
    standard deviation serves as the uncertainty estimate.

    Args:
        img_pil: Input image; converted with ``np.array`` (the caller in this
            file passes an RGB PIL image).
        n_runs: Number of independent LIME explanations to compute. Must be
            at least 1.
        num_samples: Number of samples LIME draws per run.
        segments_kwargs: Keyword arguments forwarded to ``slic``. Defaults to
            ``{"n_segments": 50, "compactness": 10}``.

    Returns:
        A tuple ``(img_np, mean_attr, std_attr, segments, pred_label, probs)``:
        the image as an array, the mean and standard deviation of the
        attribution maps across runs, a SLIC segmentation of the image (for
        drawing boundaries), the index of the predicted class, and the class
        probabilities for the unperturbed image with the batch axis squeezed
        out.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    # Fail early with a clear message; previously this only surfaced later as
    # an opaque ValueError from np.stack on an empty list.
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")
    if segments_kwargs is None:
        segments_kwargs = {"n_segments": 50, "compactness": 10}

    explainer = lime_image.LimeImageExplainer()
    img_np = np.array(img_pil)  # H,W,3 uint8

    # The prediction on the unperturbed image does not depend on the LIME
    # run, so compute it once instead of once per iteration.
    preds = classifier_fn(np.expand_dims(img_np, 0))
    pred_label = int(preds[0].argmax())

    # The same segmentation settings are used for every run, so only LIME's
    # sampling differs between runs, not the superpixels.
    def segmentation_fn(x):
        return slic(x, start_label=0, **segments_kwargs)

    run_maps = []
    for run in range(n_runs):
        logger.info("LIME run %d/%d (num_samples=%d)", run + 1, n_runs, num_samples)
        explanation = explainer.explain_instance(
            img_np,
            classifier_fn=classifier_fn,
            top_labels=5,
            hide_color=0,
            num_samples=num_samples,
            segmentation_fn=segmentation_fn,
        )
        local_exp = dict(explanation.local_exp)[pred_label]
        segments = explanation.segments  # shape (H,W) of segment ids

        # Paint each superpixel with its weight; superpixels without an entry
        # in local_exp stay at 0.
        attr_map = np.zeros(segments.shape, dtype=float)
        for seg_id, weight in local_exp:
            attr_map[segments == seg_id] = weight
        run_maps.append(attr_map)

    runs_stack = np.stack(run_maps, axis=0)
    mean_attr = runs_stack.mean(axis=0)
    std_attr = runs_stack.std(axis=0)
    logger.info("Completed %d LIME runs, mean/std shapes: %s / %s", n_runs, mean_attr.shape, std_attr.shape)

    # Segmentation for the overlay, computed with the same kwargs as the runs.
    segments_final = slic(img_np, start_label=0, **segments_kwargs)
    return img_np, mean_attr, std_attr, segments_final, pred_label, preds.squeeze()
| # ---- Step 5: Visualize results ---- | |
def plot_mean_and_uncertainty(img_np, mean_attr, std_attr, segments, pred_label, probs, cmap_mean="jet", cmap_unc="hot"):
    """Draw a 2x3 figure summarising mean attribution and its uncertainty.

    Panels, in order: the original image; the mean attribution overlaid on
    the image with segment boundaries; the attribution standard deviation;
    a histogram of mean attribution values; a histogram of standard
    deviation values; and the image with the pixels at or above the 90th
    percentile of standard deviation highlighted.

    Args:
        img_np: Image array to display (divided by 255 for the overlay).
        mean_attr: Per-pixel mean attribution map.
        std_attr: Per-pixel attribution standard deviation map.
        segments: Segment-id map used to draw superpixel boundaries.
        pred_label: Index of the predicted class.
        probs: Class probabilities indexable by ``pred_label``.
        cmap_mean: Colormap name for the mean attribution overlay.
        cmap_unc: Colormap name for the uncertainty panel.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """

    def _rescale(arr):
        # Min-max scaling to [0, 1] for display only; the epsilon avoids a
        # division by zero on a constant map. Zero attribution is not pinned
        # to any particular colour by this scaling.
        lowest = arr.min()
        highest = arr.max()
        return (arr - lowest) / (highest - lowest + 1e-8)

    mean_norm = _rescale(mean_attr)
    std_norm = _rescale(std_attr)

    # Human-readable class name and probability for the title.
    label_idx = int(pred_label)
    pred_name = model.config.id2label[label_idx]
    pred_prob = float(probs[label_idx])

    fig, axes = plt.subplots(2, 3, figsize=(15, 9))
    ax_image, ax_mean, ax_std, ax_mean_hist, ax_std_hist, ax_top = axes.flatten()

    # Panel 1: the input image.
    ax_image.imshow(img_np)
    ax_image.set_title("Original image")
    ax_image.axis("off")

    # Panel 2: segment boundaries on the image with the mean attribution
    # drawn semi-transparently on top.
    overlay = img_np.copy().astype(float) / 255.0
    ax_mean.imshow(mark_boundaries(overlay, segments, color=(1, 1, 0)))
    mean_artist = ax_mean.imshow(mean_norm, cmap=cmap_mean, alpha=0.5)
    ax_mean.set_title(f"Mean attribution (overlay)\npred: {pred_name} ({pred_prob:.3f})")
    ax_mean.axis("off")
    fig.colorbar(mean_artist, ax=ax_mean, fraction=0.046, pad=0.04)

    # Panel 3: rescaled standard deviation across runs.
    std_artist = ax_std.imshow(std_norm, cmap=cmap_unc)
    ax_std.set_title("Uncertainty (std)")
    ax_std.axis("off")
    fig.colorbar(std_artist, ax=ax_std, fraction=0.046, pad=0.04)

    # Panels 4 and 5: value distributions of the raw (unscaled) maps.
    ax_mean_hist.hist(mean_attr.ravel(), bins=50, color="C0")
    ax_mean_hist.set_title("Distribution of mean attribution")
    ax_std_hist.hist(std_attr.ravel(), bins=50, color="C1")
    ax_std_hist.set_title("Distribution of attribution std (uncertainty)")

    # Panel 6: highlight pixels whose std is at or above the 90th percentile;
    # everything else is masked out so the image shows through.
    thresh = np.percentile(std_attr, 90)
    high_std = std_attr >= thresh
    ax_top.imshow(img_np)
    ax_top.imshow(np.ma.masked_where(~high_std, high_std), cmap="Reds", alpha=0.45)
    ax_top.set_title(f"Top-10% uncertainty (threshold={thresh:.3f})")
    ax_top.axis("off")

    plt.tight_layout()
    plt.show()
| # ---- Main: run example ---- | |
def print_interpretability_summary():
    """Print a short guide on how to read the interpretability outputs."""
    # NOTE(review): the MC Dropout, TTA, predictive-entropy and
    # variation-ratio bullets describe outputs that the code visible in this
    # file does not produce; confirm whether they belong here.
    # The arrows were mis-encoded as a Greek beta in the original strings;
    # plain ASCII "->" is used so the text prints on any console encoding.
    print("\nHow to read the results (quick guide):")
    print("- LIME panel: green/highlighted superpixels are locally important for the predicted class; if background dominates, that's a red flag.")
    print("- LIME uncertainty (std): high std regions indicate unstable explanations across runs.")
    print("- MC Dropout histogram: narrow peak -> stable belief; wide/multi-modal -> epistemic uncertainty.")
    print("- TTA histogram: if small flips/crops cause big swings, prediction depends on fragile cues (aleatoric-ish sensitivity).")
    print("- Predictive entropy: higher means more uncertainty in the class distribution.")
    print("- Variation ratio: fraction of samples not in the majority class; higher -> more disagreement.\n")


def main():
    """Download an image, run repeated LIME explanations, plot and summarise."""
    logger.info("Script started")
    img = download_random_image()
    img_np, mean_attr, std_attr, segments, pred_label, probs = lime_explanations_with_uncertainty(
        img_pil=img,
        n_runs=6,  # increase for better uncertainty estimates (longer)
        num_samples=1000,  # LIME samples per run
        segments_kwargs={"n_segments": 60, "compactness": 9},
    )
    logger.info("Plotting results and finishing")
    plot_mean_and_uncertainty(img_np, mean_attr, std_attr, segments, pred_label, probs)
    # The reading guide is printed after the plot has been shown.
    print_interpretability_summary()
    logger.info("Script finished")


# ---- Main: run example ----
if __name__ == "__main__":
    main()