Update app.py
Browse files
app.py
CHANGED
|
@@ -276,29 +276,18 @@ def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
|
|
| 276 |
# Normalize CAM to [0, 1]
|
| 277 |
cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
|
| 278 |
|
| 279 |
-
# Apply threshold mask
|
| 280 |
-
mask = cam_norm >= vis_threshold
|
| 281 |
-
|
| 282 |
# Create heatmap using matplotlib colormap
|
| 283 |
colormap = cm.get_cmap('jet')
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
# Convert heatmap to PIL image
|
| 288 |
-
heatmap_pil = Image.fromarray(heatmap_rgb).convert("RGB")
|
| 289 |
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
heat_np = np.array(heatmap_pil).astype(np.float32)
|
| 293 |
|
| 294 |
-
#
|
| 295 |
-
|
| 296 |
-
blended_np[mask] = base_np[mask] * (1 - alpha) + heat_np[mask] * alpha
|
| 297 |
-
blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
|
| 298 |
|
| 299 |
-
|
| 300 |
-
blended_img = Image.fromarray(blended_np)
|
| 301 |
-
return blended_img
|
| 302 |
|
| 303 |
|
| 304 |
with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
|
|
| 276 |
# Normalize CAM to [0, 1]
|
| 277 |
cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
|
| 278 |
|
|
|
|
|
|
|
|
|
|
| 279 |
# Create heatmap using matplotlib colormap
|
| 280 |
colormap = cm.get_cmap('jet')
|
| 281 |
+
cam_colored = colormap(cam_norm)[:, :, :3] # RGB
|
| 282 |
+
cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
+
cam_rgba = np.dstack((cam_colored, cam_alpha)) # Shape: (H, W, 4)
|
| 285 |
+
cam_image = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
|
|
|
|
| 286 |
|
| 287 |
+
# Composite over original
|
| 288 |
+
composite = Image.alpha_composite(image_pil, cam_image)
|
|
|
|
|
|
|
| 289 |
|
| 290 |
+
return composite
|
|
|
|
|
|
|
| 291 |
|
| 292 |
|
| 293 |
with gr.Blocks(css=".output-class { display: none; }") as demo:
|