Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,17 +16,25 @@ model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=to
|
|
| 16 |
model.eval()
|
| 17 |
print("Model and processor loaded successfully.")
|
| 18 |
|
| 19 |
-
# --- 2.
|
| 20 |
def generate_heatmap(image_tensor, original_image, target_class_index):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# This part is correct from our last fix.
|
| 22 |
target_layer = model.swin.layernorm
|
| 23 |
|
| 24 |
-
# Initialize LayerGradCam
|
| 25 |
-
|
|
|
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
# The 'baselines' argument is not used by LayerGradCam, so we remove it.
|
| 29 |
-
# The call is now simpler and correct for this specific method.
|
| 30 |
attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
|
| 31 |
|
| 32 |
# The rest of the function remains the same.
|
|
|
|
| 16 |
model.eval()
|
| 17 |
print("Model and processor loaded successfully.")
|
| 18 |
|
| 19 |
+
# --- 2. Define the Explainability (Grad-CAM) Function ---
|
| 20 |
def generate_heatmap(image_tensor, original_image, target_class_index):
|
| 21 |
+
|
| 22 |
+
# --- THIS IS THE FIX ---
|
| 23 |
+
# We define a wrapper function that ensures our model returns a simple tensor,
|
| 24 |
+
# which is what Captum expects. It takes the model's output object and
|
| 25 |
+
# extracts the 'logits' tensor from it.
|
| 26 |
+
def model_forward_wrapper(input_tensor):
|
| 27 |
+
outputs = model(pixel_values=input_tensor)
|
| 28 |
+
return outputs.logits
|
| 29 |
+
|
| 30 |
# This part is correct from our last fix.
|
| 31 |
target_layer = model.swin.layernorm
|
| 32 |
|
| 33 |
+
# Initialize LayerGradCam, but pass our new wrapper function instead of the raw model.
|
| 34 |
+
# Captum will now use this wrapper to get the model's output.
|
| 35 |
+
lgc = LayerGradCam(model_forward_wrapper, target_layer)
|
| 36 |
|
| 37 |
+
# This call now works because `lgc` gets a proper tensor from our wrapper.
|
|
|
|
|
|
|
| 38 |
attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
|
| 39 |
|
| 40 |
# The rest of the function remains the same.
|