# ViT_timm_interp / vit_and_captum.py
'''
Example of using Captum Integrated Gradients (IG) with a Vision Transformer (ViT) model
to explain image classification predictions.

This script downloads a random image from the web, runs it through a pre-trained
ViT model, and uses Captum to compute and visualize attributions.

IG intuition: instead of only asking the model what is in the image, we ask which
parts of the picture convinced it to give that answer. It is like turning up the
brightness on a photo and seeing which regions made the model confident.
'''
import torch
from torchvision import transforms
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
import requests
import random
from io import BytesIO
import numpy as np
# Add logging
import logging, os
from logging.handlers import RotatingFileHandler
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")
logger = logging.getLogger("vit_and_captum")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    sh = logging.StreamHandler()
    fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8")
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh.setFormatter(fmt); fh.setFormatter(fmt)
    logger.addHandler(sh); logger.addHandler(fh)
# ---- Step 1: Load model ----
# Using a Vision Transformer (ViT) model from Hugging Face Transformers
from transformers import ViTForImageClassification, ViTImageProcessor
# Load pre-trained model and processor
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)
# run in eval mode for inference
model.eval()
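# A small sanity check (an optional addition, not required by the pipeline): log the
# checkpoint's size and label space so the run record shows which model produced the answer.
n_params = sum(p.numel() for p in model.parameters())
logger.info("Loaded %s: %.1fM parameters, %d labels", model_name, n_params / 1e6, len(model.config.id2label))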
# ---- Step 2: Load an image ----
# Function to download a random image from a set of free image providers
def download_random_image():
    # Search terms roughly matching ImageNet-style classes
    search_terms = ["dog", "cat", "bird", "car", "airplane", "horse", "elephant", "tiger", "lion", "bear"]
    term = random.choice(search_terms)
    # multiple providers to improve reliability
    providers = [
        f"https://source.unsplash.com/224x224/?{term}",
        f"https://picsum.photos/seed/{term}/224/224",
        f"https://loremflickr.com/224/224/{term}",
        # placekitten is a good fallback for cat-like images (serves an image for any request)
        f"https://placekitten.com/224/224"
    ]
    headers = {"User-Agent": "Mozilla/5.0 (compatible; ImageFetcher/1.0)"}
    for url in providers:
        try:
            response = requests.get(url, timeout=10, headers=headers, allow_redirects=True)
            if response.status_code != 200:
                logger.warning("Provider %s returned status %d", url, response.status_code)
                continue
            # Try to identify and open image content
            try:
                img = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception as img_err:
                logger.warning("Failed to parse image from %s: %s", url, img_err)
                continue
            # Ensure it's exactly 224x224
            try:
                img = img.resize((224, 224), Image.Resampling.LANCZOS)
            except Exception:
                # Fallback if this PIL version doesn't have Image.Resampling
                img = img.resize((224, 224), Image.LANCZOS)
            logger.info("Downloaded random image from %s for term=%s", url, term)
            return img
        except requests.RequestException as e:
            logger.warning("Request failed for %s: %s", url, e)
            continue
    logger.error("All providers failed; using fallback solid-color image.")
    img = Image.new("RGB", (224, 224), color=(128, 128, 128))
    return img
# Download and use a random image
img = download_random_image()
# Preprocess the image into a PyTorch tensor
inputs = processor(images=img, return_tensors="pt")
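# Illustrative only: the processor above is roughly equivalent to this torchvision pipeline.
# The 0.5 mean/std values are an assumption for this checkpoint; verify them against
# processor.image_mean / processor.image_std before relying on this manual path.
manual_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
pixel_values_manual = manual_preprocess(img).unsqueeze(0)  # shape [1, 3, 224, 224]; not used below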
# ---- Step 3: Run prediction ----
with torch.no_grad():  # no gradients needed for inference
    outputs = model(**inputs)  # inputs is a dict of tensors
    probs = outputs.logits.softmax(-1)  # class probabilities
    pred_idx = probs.argmax(-1).item()  # index of the most probable class
logger.info("Predicted %s (idx=%d)", model.config.id2label[pred_idx], pred_idx)
# NEW: show top-k predictions to give context
topk = 5
topk_vals, topk_idx = torch.topk(probs, k=topk)
topk_vals = topk_vals.squeeze().cpu().numpy()
topk_idx = topk_idx.squeeze().cpu().numpy()
print("Top-{} predictions:".format(topk))
for v, i in zip(topk_vals, topk_idx):
    print(f" {model.config.id2label[int(i)]:30s} {float(v):.4f}")
print("Chosen prediction:", model.config.id2label[pred_idx])
# ---- Step 4: Captum Integrated Gradients ----
from captum.attr import IntegratedGradients
# Captum expects a forward function that returns a tensor (not a ModelOutput dataclass)
def forward_func(pixel_values):
    # call the model and return the raw logits as a plain Tensor
    outputs = model(pixel_values=pixel_values)
    # outputs is a ModelOutput dataclass; return the logits tensor
    return outputs.logits
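# Quick consistency check (an added sketch): the wrapped forward should reproduce the
# logits already computed in Step 3 for the same pixel values.
with torch.no_grad():
    assert torch.allclose(forward_func(inputs["pixel_values"]), outputs.logits, atol=1e-4)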
# IntegratedGradients should be given the forward function
ig = IntegratedGradients(forward_func)
# Captum needs the inputs to require gradients
input_tensor = inputs["pixel_values"].clone().detach()
input_tensor.requires_grad_(True)
# Compute attributions for the predicted class index
# (use extra integration steps and request the convergence delta as a quality check)
attributions, convergence_delta = ig.attribute(
    input_tensor,
    target=pred_idx,
    n_steps=100,
    return_convergence_delta=True,
)
logger.info("IG convergence delta: %s", convergence_delta)
# ---- Step 5: Visualize attribution heatmap (normalized + overlay) ----
# aggregate over channels (signed mean keeps sign of contributions)
attr = attributions.squeeze().mean(dim=0).detach().cpu().numpy()
# Normalize to [-1,1] to show positive vs negative contributions with diverging colormap
min_v, max_v = float(attr.min()), float(attr.max())
norm_denom = max(abs(min_v), abs(max_v)) + 1e-8
attr_signed = attr / norm_denom # now in approx [-1,1]
# OPTIONAL: smooth heatmap slightly to make overlays more intuitive
try:
    heat_pil = Image.fromarray(np.uint8((attr_signed + 1) * 127.5))
    heat_pil = heat_pil.filter(ImageFilter.GaussianBlur(radius=1.5))
    attr_signed = (np.array(heat_pil).astype(float) / 127.5) - 1.0
except Exception:
    # If the PIL filter is not available, continue without smoothing
    pass
# Create overlay using a diverging colormap (positive = warm, negative = cool)
plt.figure(figsize=(6,6))
plt.imshow(img)
plt.imshow(attr_signed, cmap="seismic", alpha=0.45, vmin=-1, vmax=1)
cb = plt.colorbar(fraction=0.046, pad=0.04)
cb.set_label("Signed attribution (normalized)")
plt.title(f"IG overlay — pred: {model.config.id2label[pred_idx]} ({float(probs.squeeze()[pred_idx]):.3f})")
plt.axis("off")
# Show standalone signed heatmap for clearer inspection
plt.figure(figsize=(4,4))
plt.imshow(attr_signed, cmap="seismic", vmin=-1, vmax=1)
plt.colorbar()
plt.title("Signed IG Attribution (neg=blue, pos=red)")
plt.axis("off")
plt.show()
# Add concise runtime interpretability guidance
def print_interpretability_summary():
    print("\nHow to read the results (quick guide):")
    print("- IG signed heatmap: red/warm = supports the predicted class; blue/cool = opposes it.")
    print("- Normalize by max-abs when comparing images. Check IG 'convergence delta': large values mean treat attributions cautiously.")
    print("- LIME panel (if used): green/highlighted superpixels indicate locally important regions; background-dominated explanations are a red flag.")
    print("- MC Dropout histogram: narrow peak → stable belief; wide/multi-modal → epistemic uncertainty.")
    print("- TTA histogram: many flips under small augmentations → fragile/aleatoric sensitivity.")
    print("- Predictive entropy: higher → more uncertainty in the full distribution.")
    print("- Variation ratio: fraction of samples not matching majority; higher → more disagreement.\n")
print_interpretability_summary()
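# The guide above mentions predictive entropy; as a minimal sketch for the single deterministic
# forward pass in this script (MC Dropout / TTA would require repeated stochastic passes,
# which this file does not implement), it can be computed directly from `probs`:
p = probs.squeeze()
predictive_entropy = float(-(p * torch.log(p + 1e-12)).sum())
logger.info("Predictive entropy of the output distribution: %.4f nats", predictive_entropy)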