attempt to make tag vis work
--- a/app.py
+++ b/app.py
@@ -12,8 +12,6 @@ from torchvision.transforms import InterpolationMode
 import torchvision.transforms.functional as TF
 from huggingface_hub import hf_hub_download
 
-torch.set_grad_enabled(False)
-
 class Fit(torch.nn.Module):
     def __init__(
         self,
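Note: dropping the module-level `torch.set_grad_enabled(False)` is what makes the new Grad-CAM pass possible at all; with gradients disabled process-wide, `backward()` raises and the backward hook never fires. A minimal sketch of the trade-off, with a dummy module standing in for the real tagger:

```python
import torch

model = torch.nn.Linear(4, 2)   # stand-in for the real model
x = torch.randn(1, 4)

torch.set_grad_enabled(False)   # what the old module-level call did
y = model(x)
print(y.requires_grad)          # False: y.sum().backward() here would raise

torch.set_grad_enabled(True)    # state after this commit removes the call
with torch.no_grad():           # plain tagging stays cheap, scoped locally
    y = model(x)

y = model(x)                    # CAM path: the graph is recorded again
y.sum().backward()
print(model.weight.grad.shape)  # torch.Size([2, 4])
```

Scoping `torch.no_grad()` to the plain tagging path, as `run_classifier` already does, keeps inference cheap without blocking the CAM backward pass.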
@@ -155,11 +153,14 @@ for idx, tag in enumerate(allowed_tags):
     allowed_tags[idx] = tag.replace("_", " ")
 
 sorted_tag_score = {}
+input_image = None
+
 
 @spaces.GPU(duration=5)
 def run_classifier(image, threshold):
-    global sorted_tag_score
-
+    global sorted_tag_score, input_image
+    input_image = image.convert('RGBA')
+    img = input_image
     tensor = transform(img).unsqueeze(0)
 
     with torch.no_grad():
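Note: caching the upload in a module-level `input_image` works in single-user testing, but Gradio globals are shared across all concurrent sessions, so on a busy Space one visitor's click could overlay a CAM on another visitor's image. A per-session variant would route the image through `gr.State`; this sketch reuses the app's component names (`image_input`, `threshold_slider`, `tag_string`, `label_box`) but the wiring is illustrative, not the commit's code:

```python
import gradio as gr

def run_classifier(image, threshold, stored):
    stored = image.convert("RGBA")       # keep the upload in per-session state
    # ... model inference as in the app ...
    return "tag string", {}, stored

with gr.Blocks() as demo:
    stored_image = gr.State(value=None)  # one copy per browser session
    image_input = gr.Image(type="pil")
    threshold_slider = gr.Slider(0.0, 1.0, value=0.3)
    tag_string = gr.Textbox()
    label_box = gr.Label()

    image_input.upload(
        fn=run_classifier,
        inputs=[image_input, threshold_slider, stored_image],
        outputs=[tag_string, label_box, stored_image],
    )
```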
@@ -180,10 +181,124 @@ def create_tags(threshold):
     return text_no_impl, filtered_tag_score
 
 def clear_image():
-    global sorted_tag_score
+    global sorted_tag_score, input_image
+    input_image = None
     sorted_tag_score = {}
     return "", {}
 
+target_tag_index = None
+
+# Store hooks and intermediate values
+gradients = {}
+activations = {}
+
+def hook_forward(module, input, output):
+    activations['value'] = output
+
+def hook_backward(module, grad_in, grad_out):
+    gradients['value'] = grad_out[0]
+
+def cam_inference(target_tag, threshold):
+    global input_image, sorted_tag_score, target_tag_index, gradients, activations
+    img = input_image
+    tensor = transform(img).unsqueeze(0)
+
+    gradients = {}
+    activations = {}
+    cam = None
+    target_tag_index = None
+
+    if target_tag:
+        if target_tag not in allowed_tags:
+            print(f"Warning: Target tag '{target_tag}' not found in allowed tags.")
+            target_tag = None
+        else:
+            target_tag_index = allowed_tags.index(target_tag)
+    handle_forward = model.norm.register_forward_hook(hook_forward)
+    handle_backward = model.norm.register_full_backward_hook(hook_backward)
+
+    probits = model(tensor)[0].cpu()
+
+
+    if target_tag is not None and target_tag_index is not None:
+        model.zero_grad()
+        target_score = probits[target_tag_index]
+        target_score.backward(retain_graph=True)
+
+        grads = gradients.get('value')
+        acts = activations.get('value')
+
+        if grads is not None and acts is not None:
+            patch_grads = grads
+            patch_acts = acts
+
+            weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+            cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+            cam_1d = torch.relu(cam_1d)
+
+            cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+
+
+    handle_forward.remove()
+    handle_backward.remove()
+    gradients = {}
+    activations = {}
+
+    return create_cam_visualization_pil(cam, vis_threshold=threshold)
+
+def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
+    """
+    Overlays CAM on the stored input image and returns a PIL image.
+
+    Args:
+        cam: 2D numpy array (activation map)
+        alpha: float, blending factor
+        vis_threshold: float, minimum normalized CAM value to show color
+        (the base image is read from the module-level input_image)
+
+    Returns:
+        PIL.Image.Image with overlay
+    """
+    global input_image
+    if cam is None:
+        print("CAM is None, skipping visualization.")
+        return input_image
+    # Convert to RGB (in case RGBA or others)
+    image_pil = input_image.convert("RGB")
+    w, h = image_pil.size
+
+    # Resize CAM to match image
+    cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.BILINEAR))
+
+    # Normalize CAM to [0, 1]
+    cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
+
+    # Apply threshold mask
+    mask = cam_norm >= vis_threshold
+
+    # Create heatmap using matplotlib colormap
+    colormap = cm.get_cmap('jet')
+    heatmap_rgba = colormap(cam_norm)  # shape: (H, W, 4), values in [0, 1]
+    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
+
+    # Convert heatmap to PIL image
+    heatmap_pil = Image.fromarray(heatmap_rgb).convert("RGB")
+
+    # Convert images to NumPy for blending
+    base_np = np.array(image_pil).astype(np.float32)
+    heat_np = np.array(heatmap_pil).astype(np.float32)
+
+    # Blend only where mask is True
+    blended_np = base_np.copy()
+    blended_np[mask] = base_np[mask] * (1 - alpha) + heat_np[mask] * alpha
+    blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
+
+    # Convert back to PIL image
+    blended_img = Image.fromarray(blended_np)
+    return blended_img
+
+
 with gr.Blocks(css=".output-class { display: none; }") as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
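Note: the CAM computation itself is standard Grad-CAM applied at `model.norm`: average the gradients over patches to get per-channel weights, then score each patch by the weighted sum of its activations. The `reshape(27, 27)` presumes a 27×27 patch grid; the embedding width below (1152, SigLIP-so400m-style) is likewise an assumption for the sketch, not a value read from this model:

```python
import torch

patches, embed_dim = 27 * 27, 1152           # assumed ViT geometry for the sketch
acts = torch.randn(1, patches, embed_dim)    # what the forward hook captures at model.norm
grads = torch.randn(1, patches, embed_dim)   # what the backward hook captures (d score / d acts)

weights = torch.mean(grads, dim=1).squeeze(0)               # (embed_dim,) channel importances
cam_1d = torch.einsum("pe,e->p", acts.squeeze(0), weights)  # weighted sum per patch
cam_1d = torch.relu(cam_1d)                                 # keep positively contributing patches
cam = cam_1d.reshape(27, 27)                                # 2-D heat grid, upsampled later
print(cam.shape)                                            # torch.Size([27, 27])
```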
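Note: two calls in `create_cam_visualization_pil` have since aged out of their libraries: `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9, and the `ndarray.ptp()` method was removed in NumPy 2.0 (the `np.ptp` function remains). A version-safe sketch of the same normalize-threshold-blend step on a dummy map:

```python
import numpy as np
import matplotlib

colormap = matplotlib.colormaps["jet"]               # registry lookup, replaces cm.get_cmap("jet")

cam = np.random.rand(27, 27).astype(np.float32)      # stand-in for the resized CAM
cam_norm = (cam - cam.min()) / (np.ptp(cam) + 1e-8)  # normalize to [0, 1]
mask = cam_norm >= 0.2                               # vis_threshold
heat = (colormap(cam_norm)[:, :, :3] * 255).astype(np.float32)

base = np.zeros((27, 27, 3), dtype=np.float32)       # stand-in for the RGB photo
alpha = 0.6
base[mask] = base[mask] * (1 - alpha) + heat[mask] * alpha
blended = np.clip(base, 0, 255).astype(np.uint8)
```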
@@ -219,5 +334,11 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
         outputs=[tag_string, label_box]
     )
 
+    label_box.select(
+        fn=cam_inference,
+        inputs=[threshold_slider],
+        outputs=[image_input]
+    )
+
 if __name__ == "__main__":
     demo.launch()
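Note: as committed, the wiring looks inconsistent with the handler: `cam_inference(target_tag, threshold)` expects two inputs, but `inputs=[threshold_slider]` supplies only one, so `target_tag` would receive the slider value and the clicked tag would never arrive. Gradio delivers the clicked `Label` entry through a `gr.SelectData` event argument; a sketch of the signature `.select` would typically need (the body is elided, not the commit's code):

```python
import gradio as gr

def cam_inference(evt: gr.SelectData, threshold):
    target_tag = evt.value  # the label entry the user clicked on
    # ... register hooks, backward from the tag's logit, build the overlay ...
    return None             # placeholder; the app returns the blended PIL image

# wiring stays: label_box.select(fn=cam_inference, inputs=[threshold_slider], outputs=[image_input])
```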