Update app.py
Browse files
app.py
CHANGED
|
@@ -242,13 +242,16 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
|
|
| 242 |
if cam is None:
|
| 243 |
return image_pil
|
| 244 |
w, h = image_pil.size
|
|
|
|
| 245 |
|
| 246 |
# Resize CAM to match image
|
| 247 |
-
cam_resized = np.array(Image.fromarray(cam).resize((
|
| 248 |
|
| 249 |
# Normalize CAM to [0, 1]
|
| 250 |
cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
|
| 251 |
|
|
|
|
|
|
|
| 252 |
# Create heatmap using matplotlib colormap
|
| 253 |
colormap = cm.get_cmap('jet')
|
| 254 |
cam_colored = colormap(cam_norm)[:, :, :3] # RGB
|
|
@@ -304,7 +307,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
| 304 |
fn=create_tags,
|
| 305 |
inputs=[threshold_slider, sorted_tag_score_state],
|
| 306 |
outputs=[tag_string, label_box],
|
| 307 |
-
show_progress='
|
| 308 |
)
|
| 309 |
|
| 310 |
label_box.select(
|
|
@@ -318,14 +321,14 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
| 318 |
fn=create_cam_visualization_pil,
|
| 319 |
inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
|
| 320 |
outputs=[image_input],
|
| 321 |
-
show_progress='
|
| 322 |
)
|
| 323 |
|
| 324 |
alpha_slider.input(
|
| 325 |
fn=create_cam_visualization_pil,
|
| 326 |
inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
|
| 327 |
outputs=[image_input],
|
| 328 |
-
show_progress='
|
| 329 |
)
|
| 330 |
|
| 331 |
if __name__ == "__main__":
|
|
|
|
| 242 |
if cam is None:
|
| 243 |
return image_pil
|
| 244 |
w, h = image_pil.size
|
| 245 |
+
size = max(w, h)
|
| 246 |
|
| 247 |
# Resize CAM to match image
|
| 248 |
+
cam_resized = np.array(Image.fromarray(cam).resize((size, size), resample=Image.Resampling.BILINEAR))
|
| 249 |
|
| 250 |
# Normalize CAM to [0, 1]
|
| 251 |
cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
|
| 252 |
|
| 253 |
+
cam_norm = transforms.CenterCrop((h, w))(cam_norm)
|
| 254 |
+
|
| 255 |
# Create heatmap using matplotlib colormap
|
| 256 |
colormap = cm.get_cmap('jet')
|
| 257 |
cam_colored = colormap(cam_norm)[:, :, :3] # RGB
|
|
|
|
| 307 |
fn=create_tags,
|
| 308 |
inputs=[threshold_slider, sorted_tag_score_state],
|
| 309 |
outputs=[tag_string, label_box],
|
| 310 |
+
show_progress='hidden'
|
| 311 |
)
|
| 312 |
|
| 313 |
label_box.select(
|
|
|
|
| 321 |
fn=create_cam_visualization_pil,
|
| 322 |
inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
|
| 323 |
outputs=[image_input],
|
| 324 |
+
show_progress='hidden'
|
| 325 |
)
|
| 326 |
|
| 327 |
alpha_slider.input(
|
| 328 |
fn=create_cam_visualization_pil,
|
| 329 |
inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
|
| 330 |
outputs=[image_input],
|
| 331 |
+
show_progress='hidden'
|
| 332 |
)
|
| 333 |
|
| 334 |
if __name__ == "__main__":
|