Update app.py
app.py
CHANGED
@@ -1,18 +1,21 @@
-
 import torch
 import gradio as gr
 import numpy as np
 from PIL import Image
 import torchvision.transforms.functional as TF
 from matplotlib import colormaps
 from transformers import AutoModel
-import os
 
 # ----------------------------
 # Configuration
 # ----------------------------
-#
-
 PATCH_SIZE = 16
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -21,32 +24,49 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
 # ----------------------------
-# Model Loading (
 # ----------------------------
-
-
-
     try:
-
-
-
-
-
-        print(f"✅ Model loaded successfully on device: {DEVICE}")
-        return model
     except Exception as e:
-        print(f"❌ Failed to load model: {e}")
-        # This will display a clear error message in the Gradio interface
         raise gr.Error(
-            f"Could not load model '{
-            "
-            "and set
             f"Original error: {e}"
         )
 
-
-model
-
 # ----------------------------
 # Helper Functions
 # ----------------------------
@@ -85,9 +105,11 @@ def generate_pca_visuals(
     resolution: int,
     cmap_name: str,
     overlay_alpha: float,
     progress=gr.Progress(track_tqdm=True)
 ):
     """Main function to generate PCA visuals."""
     if model is None:
         raise gr.Error("DINOv3 model is not available. Check the startup logs.")
     if image_pil is None:
@@ -105,9 +127,8 @@ def generate_pca_visuals(
     progress(0.5, desc="🦖 Extracting features with DINOv3...")
     outputs = model(t_norm)
 
-    #
-
-    n_special_tokens = 5 # 1 [CLS] token + 4 register tokens for ViT-H/16+
     patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
 
     # 3. PCA Calculation
@@ -115,8 +136,7 @@ def generate_pca_visuals(
     X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
     U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
 
-    #
-    # This prevents the colors from randomly inverting on different runs.
     for i in range(V.shape[1]):
         max_abs_idx = torch.argmax(torch.abs(V[:, i]))
         if V[max_abs_idx, i] < 0:
@@ -134,7 +154,6 @@
     )
 
     # 5. Create Visualizations
-    # This part should now work correctly as `scores` has the right shape (Hp*Wp, 3)
     pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
     pc1_image_raw = colorize(pc1_map, cmap_name)
 
@@ -155,10 +174,10 @@
 # ----------------------------
 # Gradio Interface
 # ----------------------------
-with gr.Blocks(theme=gr.themes.Soft(), title="
     gr.Markdown(
         """
-        #
         Upload an image to visualize the principal components of its patch features.
         This reveals the main axes of semantic variation within the image as understood by the model.
         """
@@ -166,7 +185,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait
 
     with gr.Row():
         with gr.Column(scale=2):
-            # Added a default image URL for convenience
             input_image = gr.Image(type="pil", label="Upload Image", value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
 
             with gr.Accordion("⚙️ Visualization Controls", open=True):
@@ -175,6 +193,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait
                     label="Processing Resolution",
                     info="Higher values capture more detail but are slower."
                 )
                 cmap_dropdown = gr.Dropdown(
                     ['viridis', 'magma', 'inferno', 'plasma', 'cividis', 'jet'],
                     value='viridis',
@@ -201,7 +225,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait
 
     run_button.click(
         fn=generate_pca_visuals,
-        inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider],
         outputs=[output_pc1, output_rgb, output_variance, output_blended, output_processed]
     )
 
app.py, updated version (added lines marked "+"):

  1 | +import os
  2 |  import torch
  3 | +import torch.nn.functional as F
  4 |  import gradio as gr
  5 |  import numpy as np
  6 |  from PIL import Image
  7 |  import torchvision.transforms.functional as TF
  8 |  from matplotlib import colormaps
  9 |  from transformers import AutoModel
 10 |
 11 |  # ----------------------------
 12 |  # Configuration
 13 |  # ----------------------------
 14 | +# Define available models
 15 | +DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
 16 | +ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
 17 | +AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
 18 | +
 19 |  PATCH_SIZE = 16
 20 |  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 21 |
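The new import torch.nn.functional as F is not used anywhere in the changed hunks; in this kind of app it is typically needed to upsample the patch-level PCA maps back to pixel resolution before overlaying them on the input. A minimal sketch of that pattern (tensor names and sizes here are illustrative, not taken from app.py):

    # Illustrative only: upsample a (Hp, Wp) patch map to the processed image size.
    import torch
    import torch.nn.functional as F

    Hp, Wp, H, W = 14, 14, 224, 224          # example patch grid and target size
    patch_map = torch.rand(Hp, Wp)           # e.g. one principal-component score per patch
    upsampled = F.interpolate(
        patch_map[None, None],               # add batch and channel dims -> (1, 1, Hp, Wp)
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )[0, 0]                                  # back to (H, W)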
 24 |  IMAGENET_STD = (0.229, 0.224, 0.225)
 25 |
 26 |  # ----------------------------
 27 | +# Model Loading (with caching)
 28 |  # ----------------------------
 29 | +_model_cache = {}
 30 | +_current_model_id = None
 31 | +model = None # global reference
 32 | +
 33 | +def load_model_from_hub(model_id: str):
 34 | +    """Loads a DINOv3 model from the Hugging Face Hub."""
 35 | +    print(f"Loading model '{model_id}' from Hugging Face Hub...")
 36 |      try:
 37 | +        token = os.environ.get("HF_TOKEN") # optional, for gated models
 38 | +        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
 39 | +        mdl.to(DEVICE).eval()
 40 | +        print(f"✅ Model '{model_id}' loaded successfully on device: {DEVICE}")
 41 | +        return mdl
 42 |      except Exception as e:
 43 | +        print(f"❌ Failed to load model '{model_id}': {e}")
 44 |          raise gr.Error(
 45 | +            f"Could not load model '{model_id}'. "
 46 | +            "If the model is gated, please accept the terms on its Hugging Face page "
 47 | +            "and set HF_TOKEN in your environment. "
 48 |              f"Original error: {e}"
 49 |          )
 50 |
 51 | +def get_model(model_id: str):
 52 | +    """Return a cached model if available, otherwise load and cache it."""
 53 | +    if model_id in _model_cache:
 54 | +        return _model_cache[model_id]
 55 | +    mdl = load_model_from_hub(model_id)
 56 | +    _model_cache[model_id] = mdl
 57 | +    return mdl
 58 | +
 59 | +# Load the default model at startup
 60 | +model = get_model(DEFAULT_MODEL_ID)
 61 | +_current_model_id = DEFAULT_MODEL_ID
 62 | +
 63 | +def _ensure_model(model_id: str):
 64 | +    """Ensure the global 'model' matches the dropdown selection."""
 65 | +    global model, _current_model_id
 66 | +    if model_id != _current_model_id:
 67 | +        model = get_model(model_id)
 68 | +        _current_model_id = model_id
 69 | +
 70 |  # ----------------------------
 71 |  # Helper Functions
 72 |  # ----------------------------
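A small usage sketch of the loading and caching path above: a second get_model call for the same ID is served from _model_cache, and _ensure_model only reloads when the selected ID actually changes. Names are the ones defined in this file.

    # Usage sketch for the functions defined above.
    m1 = get_model(DEFAULT_MODEL_ID)
    m2 = get_model(DEFAULT_MODEL_ID)
    assert m1 is m2                      # cached instance is reused, no second Hub download

    _ensure_model(ALT_MODEL_ID)          # loads ViT-H/16+ on first use and swaps the global `model`
    _ensure_model(ALT_MODEL_ID)          # no-op: the requested backbone is already active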
105 |      resolution: int,
106 |      cmap_name: str,
107 |      overlay_alpha: float,
108 | +    model_id: str,
109 |      progress=gr.Progress(track_tqdm=True)
110 |  ):
111 |      """Main function to generate PCA visuals."""
112 | +    _ensure_model(model_id)
113 |      if model is None:
114 |          raise gr.Error("DINOv3 model is not available. Check the startup logs.")
115 |      if image_pil is None:
127 |      progress(0.5, desc="🦖 Extracting features with DINOv3...")
128 |      outputs = model(t_norm)
129 |
130 | +    # The model output includes a [CLS] token AND 4 register tokens.
131 | +    n_special_tokens = 5
132 |      patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
133 |
134 |      # 3. PCA Calculation
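One way to sanity-check the n_special_tokens = 5 assumption (1 [CLS] + 4 register tokens) is to compare the token count against the patch grid. A sketch, assuming t_norm is the (1, 3, H, W) tensor fed to the model and H and W are multiples of PATCH_SIZE:

    # Sketch: token count should equal 1 [CLS] + 4 registers + Hp * Wp patches.
    _, _, H, W = t_norm.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
    expected = 1 + 4 + Hp * Wp
    assert outputs.last_hidden_state.shape[1] == expected, (
        f"got {outputs.last_hidden_state.shape[1]} tokens, expected {expected}"
    )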
136 |      X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
137 |      U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
138 |
139 | +    # Stabilize the signs of the eigenvectors for deterministic output.
140 |      for i in range(V.shape[1]):
141 |          max_abs_idx = torch.argmax(torch.abs(V[:, i]))
142 |          if V[max_abs_idx, i] < 0:
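The projection onto the stabilized components happens in unchanged lines that this diff does not show. A plausible sketch of that step, assuming scores is obtained by projecting the centered features onto V and then min-max normalizing each component, giving shape (Hp*Wp, 3):

    # Assumed implementation of the projection step (not part of the diff).
    scores = X_centered @ V                                   # (num_patches, 3)
    mins = scores.min(0, keepdim=True).values
    maxs = scores.max(0, keepdim=True).values
    scores = (scores - mins) / (maxs - mins + 1e-8)           # each component scaled to [0, 1]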
154 |      )
155 |
156 |      # 5. Create Visualizations
157 |      pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
158 |      pc1_image_raw = colorize(pc1_map, cmap_name)
159 |
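colorize lives in the unchanged helper section, so its body is not part of this diff. A plausible minimal version, assuming it maps a 2-D score map through a matplotlib colormap and returns a PIL image at patch resolution:

    # Assumed colorize() helper, consistent with how it is called above.
    import numpy as np
    from PIL import Image
    from matplotlib import colormaps

    def colorize(arr_2d: np.ndarray, cmap_name: str) -> Image.Image:
        arr = arr_2d.astype(np.float32)
        arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)   # scale to [0, 1]
        rgba = colormaps[cmap_name](arr)                           # (Hp, Wp, 4) floats
        return Image.fromarray((rgba[..., :3] * 255).astype(np.uint8))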
174 |  # ----------------------------
175 |  # Gradio Interface
176 |  # ----------------------------
177 | +with gr.Blocks(theme=gr.themes.Soft(), title="🦖 DINOv3 PCA Explorer") as demo:
178 |      gr.Markdown(
179 |          """
180 | +        # 🦖 DINOv3 PCA Explorer
181 |          Upload an image to visualize the principal components of its patch features.
182 |          This reveals the main axes of semantic variation within the image as understood by the model.
183 |          """
185 |
186 |      with gr.Row():
187 |          with gr.Column(scale=2):
188 |              input_image = gr.Image(type="pil", label="Upload Image", value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
189 |
190 |              with gr.Accordion("⚙️ Visualization Controls", open=True):
193 |                      label="Processing Resolution",
194 |                      info="Higher values capture more detail but are slower."
195 |                  )
196 | +                model_choice = gr.Dropdown(
197 | +                    choices=AVAILABLE_MODELS,
198 | +                    value=DEFAULT_MODEL_ID,
199 | +                    label="Backbone (DINOv3)",
200 | +                    info="ViT-S/16+ is smaller & faster; ViT-H/16+ is larger.",
201 | +                )
202 |                  cmap_dropdown = gr.Dropdown(
203 |                      ['viridis', 'magma', 'inferno', 'plasma', 'cividis', 'jet'],
204 |                      value='viridis',
225 |
226 |      run_button.click(
227 |          fn=generate_pca_visuals,
228 | +        inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider, model_choice],
229 |          outputs=[output_pc1, output_rgb, output_variance, output_blended, output_processed]
230 |      )
231 |
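The listing above stops inside the gr.Blocks wiring; the unchanged tail of app.py is not part of this diff. For reference, a self-contained sketch of the feature-extraction path the app builds on — the image path is a placeholder, and the model may require accepting its license and setting HF_TOKEN:

    # Standalone sketch of the DINOv3 patch-feature -> PCA path (details assumed, not from app.py).
    import torch
    import torchvision.transforms.functional as TF
    from PIL import Image
    from transformers import AutoModel

    PATCH_SIZE = 16
    model = AutoModel.from_pretrained("facebook/dinov3-vits16plus-pretrain-lvd1689m")
    model.eval()

    img = Image.open("example.jpg").convert("RGB").resize((224, 224))   # placeholder image
    t = TF.normalize(TF.to_tensor(img), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)).unsqueeze(0)

    with torch.no_grad():
        tokens = model(t).last_hidden_state.squeeze(0)[5:]              # drop [CLS] + 4 register tokens

    X = tokens.float() - tokens.float().mean(0, keepdim=True)
    U, S, V = torch.pca_lowrank(X, q=3, center=False)
    scores = X @ V                                                      # (num_patches, 3)
    print(scores.reshape(224 // PATCH_SIZE, 224 // PATCH_SIZE, 3).shape)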