Spaces:
Sleeping
Sleeping
| ''' | |
| Example of using Captum Integrated Gradients with a Vision Transformer (ViT) model | |
| to explain image classification predictions. | |
| This example downloads a random image from the web, runs it through a pre-trained | |
| ViT model, and uses Captum to compute and visualize attributions. | |
| IG: It’s like asking the computer not just what’s in the image, | |
| but which parts of the picture convinced it to give that answer. | |
| IG: Integrated Gradients | |
| Like turning up the brightness on a photo and seeing which parts | |
| of the picture made the model confident in its answer. | |
| ''' | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import requests | |
| import random | |
| from io import BytesIO | |
| import numpy as np | |
| from PIL import Image as PILImage | |
| import requests | |
| import random | |
| from PIL import ImageFilter | |
| # Add logging | |
| import logging, os | |
| from logging.handlers import RotatingFileHandler | |
# Logging: one named logger that writes both to the console and to a
# size-capped rotating file kept in a "logs" directory next to this script.
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")

logger = logging.getLogger("vit_and_captum")
# Only configure once, so re-running/re-importing does not attach duplicate handlers.
if not logger.handlers:
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh = logging.StreamHandler()
    # Roll over at ~5 MB and keep 3 old files.
    fh = RotatingFileHandler(
        logfile,
        maxBytes=5_000_000,
        backupCount=3,
        encoding="utf-8",
    )
    for _log_handler in (sh, fh):
        _log_handler.setFormatter(fmt)
        logger.addHandler(_log_handler)
# ---- Step 1: Load model ----
# A pre-trained Vision Transformer (ViT) classifier from Hugging Face
# Transformers, plus its matching image processor (which turns PIL images
# into the model's input tensors further below).
from transformers import ViTForImageClassification, ViTImageProcessor

model_name = "google/vit-base-patch16-224"  # Hugging Face Hub checkpoint id
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode happen in a single expression.
model = ViTForImageClassification.from_pretrained(model_name).eval()
processor = ViTImageProcessor.from_pretrained(model_name)
# ---- Step 2: Load an image ----
def _resize_square(img, size):
    """Return *img* resized to ``size`` x ``size`` pixels with Lanczos resampling.

    Newer Pillow releases expose the filter as ``Image.Resampling.LANCZOS``;
    older ones only have the module-level ``Image.LANCZOS`` constant.
    """
    try:
        lanczos = Image.Resampling.LANCZOS
    except AttributeError:
        lanczos = Image.LANCZOS
    return img.resize((size, size), lanczos)


def download_random_image(size=224, timeout=10):
    """Download a random themed photo, falling back to a plain gray image.

    A search term is picked at random from a small list of common object
    categories, and several free placeholder-image services are tried in
    order. The first response that decodes as an image is converted to RGB
    and resized to ``size`` x ``size``.

    Args:
        size: Edge length, in pixels, of the returned square image.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode. If every provider fails, a solid
        gray (128, 128, 128) image of the same size is returned and an error
        is logged; request failures are logged and never raised.
    """
    search_terms = ["dog", "cat", "bird", "car", "airplane", "horse", "elephant", "tiger", "lion", "bear"]
    term = random.choice(search_terms)
    # Several providers are tried in order to improve reliability.
    # NOTE(review): source.unsplash.com and placekitten.com have reportedly
    # been discontinued, and picsum.photos appears to use `term` only as a
    # random seed rather than a keyword search -- confirm. Any failure simply
    # falls through to the next provider.
    providers = [
        f"https://source.unsplash.com/{size}x{size}/?{term}",
        f"https://picsum.photos/seed/{term}/{size}/{size}",
        f"https://loremflickr.com/{size}/{size}/{term}",
        # placekitten serves a cat image for any request, so it is the last resort.
        f"https://placekitten.com/{size}/{size}",
    ]
    headers = {"User-Agent": "Mozilla/5.0 (compatible; ImageFetcher/1.0)"}
    for url in providers:
        try:
            response = requests.get(url, timeout=timeout, headers=headers, allow_redirects=True)
        except requests.RequestException as e:
            logger.warning("Request failed for %s: %s", url, e)
            continue
        if response.status_code != 200:
            logger.warning("Provider %s returned status %d", url, response.status_code)
            continue
        # A 200 response is not guaranteed to be an image (e.g. an HTML page),
        # so any decode failure is treated as a miss for this provider.
        try:
            img = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as img_err:
            logger.warning("Failed to parse image from %s: %s", url, img_err)
            continue
        img = _resize_square(img, size)
        logger.info("Downloaded random image from %s for term=%s", url, term)
        return img
    logger.error("All providers failed; using fallback solid-color image.")
    return Image.new("RGB", (size, size), color=(128, 128, 128))
# Fetch the image to explain, then convert it into the model's PyTorch input tensors.
img = download_random_image()
inputs = processor(return_tensors="pt", images=img)
# ---- Step 3: Run prediction ----
with torch.no_grad():  # inference only, so skip building the autograd graph
    outputs = model(**inputs)  # `inputs` is the processor's dict-like output
probs = outputs.logits.softmax(-1)  # probabilities over all classes (last dim)
pred_idx = probs.argmax(-1).item()  # most probable class; .item() assumes a single image
logger.info("Predicted %s (idx=%d)", model.config.id2label[pred_idx], pred_idx)

# Show the top-k predictions to give context.
topk = 5
# Clamp k so torch.topk never asks for more classes than the model has, and
# index the single image explicitly: squeeze() would produce 0-d arrays for
# k == 1, which zip() cannot iterate.
k = min(topk, probs.shape[-1])
topk_vals, topk_idx = torch.topk(probs[0], k=k)
topk_vals = topk_vals.cpu().numpy()
topk_idx = topk_idx.cpu().numpy()
print(f"Top-{k} predictions:")
for prob, label_idx in zip(topk_vals, topk_idx):
    print(f" {model.config.id2label[int(label_idx)]:30s} {float(prob):.4f}")
print("Chosen prediction:", model.config.id2label[pred_idx])
# ---- Step 4: Captum Integrated Gradients ----
from captum.attr import IntegratedGradients


def forward_func(pixel_values):
    """Run the ViT model on *pixel_values* and return its raw logits tensor.

    Captum needs a forward callable that returns a Tensor, while the Hugging
    Face model returns a ModelOutput dataclass, so ``.logits`` is unwrapped here.
    """
    return model(pixel_values=pixel_values).logits


ig = IntegratedGradients(forward_func)

# Attribute with respect to a detached copy of the preprocessed pixels. Captum
# enables gradients on the interpolated inputs it builds internally, so the
# tensor itself does not need requires_grad.
input_tensor = inputs["pixel_values"].clone().detach()

# Number of points on the straight-line path from baseline to input.
IG_N_STEPS = 100
# Evaluate the path in chunks rather than one IG_N_STEPS-sized batch, which
# bounds peak memory of the forward/backward pass through the ViT. Results are
# unchanged apart from floating-point summation order.
IG_INTERNAL_BATCH_SIZE = 16

# NOTE(review): no `baselines` is passed, so Captum's default all-zeros
# baseline is used. That is zero in the processor's *normalized* pixel space,
# which is not necessarily a black image -- confirm against the processor's
# mean/std if a specific reference image matters.
attributions, convergence_delta = ig.attribute(
    input_tensor,
    target=pred_idx,
    n_steps=IG_N_STEPS,
    internal_batch_size=IG_INTERNAL_BATCH_SIZE,
    return_convergence_delta=True,
)
# Log a plain number instead of a tensor repr; larger values mean the path
# integral approximates the logit difference less well.
logger.info("IG convergence delta (max |delta|): %.6f", convergence_delta.abs().max().item())
# ---- Step 5: Visualize attribution heatmap (normalized + overlay) ----
HEATMAP_BLUR_SIGMA = 1.5  # std-dev in pixels of the light smoothing; <= 0 disables it


def _gaussian_blur_2d(arr, sigma):
    """Return a Gaussian-smoothed float copy of the 2-D array *arr*.

    The blur is separable (rows, then columns), uses a kernel truncated at
    3 sigma, and pads edges by repeating the border value, so the output has
    the same shape as the input. ``sigma <= 0`` returns an unblurred copy.
    Working in float avoids the 8-bit quantization and floor-truncation bias
    of round-tripping the map through a uint8 image.
    """
    arr = np.asarray(arr, dtype=float)
    if sigma <= 0:
        return arr.copy()
    half = int(np.ceil(3 * sigma))
    offsets = np.arange(-half, half + 1, dtype=float)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()
    padded = np.pad(arr, half, mode="edge")
    # "valid" convolution trims exactly `half` samples from each end, undoing the padding.
    rows = np.apply_along_axis(np.convolve, 1, padded, kernel, mode="valid")
    return np.apply_along_axis(np.convolve, 0, rows, kernel, mode="valid")


# Collapse the channel axis with a signed mean so supporting (+) and opposing
# (-) evidence keep their sign. attributions[0] selects the single image;
# presumably (channels, H, W) like pixel_values -- TODO confirm.
attr = attributions[0].mean(dim=0).detach().cpu().numpy()

# Scale by the largest magnitude so values lie in [-1, 1] for the diverging
# colormap; the epsilon guards against an all-zero attribution map.
norm_denom = float(np.abs(attr).max()) + 1e-8
attr_signed = attr / norm_denom

# Light smoothing makes the overlay easier to read. A Gaussian blur is a
# weighted average, so values stay within [-1, 1].
attr_signed = _gaussian_blur_2d(attr_signed, HEATMAP_BLUR_SIGMA)

pred_label = model.config.id2label[pred_idx]
pred_prob = float(probs[0, pred_idx])

# Overlay using a diverging colormap (positive = warm/red, negative = cool/blue).
# `extent` stretches the heatmap over the photo's pixel grid so the two stay
# aligned even if their resolutions differ.
overlay_extent = (-0.5, img.width - 0.5, img.height - 0.5, -0.5)
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.imshow(attr_signed, cmap="seismic", alpha=0.45, vmin=-1, vmax=1, extent=overlay_extent)
cb = plt.colorbar(fraction=0.046, pad=0.04)
cb.set_label("Signed attribution (normalized)")
plt.title(f"IG overlay — pred: {pred_label} ({pred_prob:.3f})")
plt.axis("off")

# Standalone signed heatmap for clearer inspection.
plt.figure(figsize=(4, 4))
plt.imshow(attr_signed, cmap="seismic", vmin=-1, vmax=1)
plt.colorbar()
plt.title("Signed IG Attribution (neg=blue, pos=red)")
plt.axis("off")
plt.show()
# Add concise runtime interpretability guidance
def print_interpretability_summary(*, include_extras=False):
    """Print a short guide to reading this script's plots and log output.

    Args:
        include_extras: Also print notes for LIME, MC Dropout, test-time
            augmentation (TTA), predictive entropy and variation ratio. This
            script produces none of those outputs, so they are omitted by
            default rather than describing panels the user never sees.
    """
    lines = [
        "\nHow to read the results (quick guide):",
        "- IG signed heatmap: red/warm = supports the predicted class; blue/cool = opposes it.",
        "- Normalize by max-abs when comparing images. Check IG 'convergence delta' — large values mean treat attributions cautiously.",
    ]
    if include_extras:
        lines.extend([
            "- LIME panel (if used): green/highlighted superpixels indicate locally important regions; background-dominated explanations are a red flag.",
            "- MC Dropout histogram: narrow peak → stable belief; wide/multi-modal → epistemic uncertainty.",
            "- TTA histogram: many flips under small augmentations → fragile/aleatoric sensitivity.",
            "- Predictive entropy: higher → more uncertainty in the full distribution.",
            "- Variation ratio: fraction of samples not matching majority; higher → more disagreement.",
        ])
    for line in lines:
        print(line)
    print()  # trailing blank line separating the guide from later output


print_interpretability_summary()