Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,8 @@ from torchvision.transforms import transforms
|
|
| 11 |
from torchvision.transforms import InterpolationMode
|
| 12 |
import torchvision.transforms.functional as TF
|
| 13 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class Fit(torch.nn.Module):
|
| 16 |
def __init__(
|
|
@@ -198,7 +200,8 @@ def hook_forward(module, input, output):
|
|
| 198 |
def hook_backward(module, grad_in, grad_out):
|
| 199 |
gradients['value'] = grad_out[0]
|
| 200 |
|
| 201 |
-
def cam_inference(
|
|
|
|
| 202 |
print(f"target_tag: {target_tag}")
|
| 203 |
global input_image, sorted_tag_score, target_tag_index, gradients, activations
|
| 204 |
img = input_image
|
|
@@ -268,7 +271,7 @@ def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
|
|
| 268 |
w, h = image_pil.size
|
| 269 |
|
| 270 |
# Resize CAM to match image
|
| 271 |
-
cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.BILINEAR))
|
| 272 |
|
| 273 |
# Normalize CAM to [0, 1]
|
| 274 |
cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
|
|
@@ -335,7 +338,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
| 335 |
|
| 336 |
label_box.select(
|
| 337 |
fn=cam_inference,
|
| 338 |
-
inputs=[
|
| 339 |
outputs=[image_input]
|
| 340 |
)
|
| 341 |
|
|
|
|
| 11 |
from torchvision.transforms import InterpolationMode
|
| 12 |
import torchvision.transforms.functional as TF
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.cm as cm
|
| 16 |
|
| 17 |
class Fit(torch.nn.Module):
|
| 18 |
def __init__(
|
|
|
|
| 200 |
def hook_backward(module, grad_in, grad_out):
|
| 201 |
gradients['value'] = grad_out[0]
|
| 202 |
|
| 203 |
+
def cam_inference(threshold, evt: gr.SelectData):
|
| 204 |
+
target_tag = evt.value
|
| 205 |
print(f"target_tag: {target_tag}")
|
| 206 |
global input_image, sorted_tag_score, target_tag_index, gradients, activations
|
| 207 |
img = input_image
|
|
|
|
| 271 |
w, h = image_pil.size
|
| 272 |
|
| 273 |
# Resize CAM to match image
|
| 274 |
+
cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.Resampling.BILINEAR))
|
| 275 |
|
| 276 |
# Normalize CAM to [0, 1]
|
| 277 |
cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
|
|
|
|
| 338 |
|
| 339 |
label_box.select(
|
| 340 |
fn=cam_inference,
|
| 341 |
+
inputs=[threshold_slider],
|
| 342 |
outputs=[image_input]
|
| 343 |
)
|
| 344 |
|