import gradio as gr
import torch
import torch.nn.functional as F  # used to upsample the Grad-CAM heatmap
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg  # renders the heatmap figure to an array
import requests
from io import BytesIO

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")
# --- 2. Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum expects a callable that returns raw logits, so we wrap the
    # Hugging Face model, whose forward returns an output object.
    def model_forward_wrapper(input_tensor):
        outputs = model(pixel_values=input_tensor)
        return outputs.logits

    # Target the final layer norm of the Swin Transformer backbone.
    target_layer = model.swin.layernorm

    lgc = LayerGradCam(model_forward_wrapper, target_layer)
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # Transformer backbones yield a flat sequence of patch attributions rather
    # than a 2D grid, so the sequence must be reshaped into a grid and upsampled.

    # 1. Determine the square grid size from the attribution length
    #    (e.g., 49 values -> a 7x7 grid).
    num_patches = attributions.shape[-1]
    grid_size = int(np.sqrt(num_patches))
    if grid_size * grid_size != num_patches:
        raise ValueError(f"Cannot reshape {num_patches} attribution values into a square grid.")

    # 2. Reshape the flat attributions into a 2D grid.
    heatmap = attributions.squeeze(0).squeeze(0).reshape(grid_size, grid_size)

    # 3. Upsample the small grid to the original image size for the overlay.
    #    interpolate() takes (N, C, H, W) input, and its `size` argument is
    #    (H, W) while PIL's Image.size is (W, H), hence the reversal.
    heatmap = heatmap.unsqueeze(0).unsqueeze(0)
    upsampled_heatmap = F.interpolate(heatmap, size=original_image.size[::-1], mode="bilinear", align_corners=False)

    # 4. Prepare the heatmap for visualization: visualize_image_attr expects an
    #    (H, W, C) numpy array, so add a channel dimension to the (H, W) map.
    heatmap_for_viz = upsampled_heatmap.squeeze().detach().cpu().numpy()
    fig, _ = viz.visualize_image_attr(
        np.expand_dims(heatmap_for_viz, axis=-1),
        np.array(original_image),
        method="blended_heat_map",
        sign="positive",  # relu_attributions=True guarantees non-negative values
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )

    # visualize_image_attr returns a matplotlib Figure, which Gradio's Image
    # component cannot display directly; rasterize it to an RGB array instead.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())[..., :3]
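# Quick sanity check of the reshape/upsample path above, using a dummy
# attribution tensor (shapes are illustrative; the real trailing dimension
# depends on the checkpoint):
#   dummy = torch.rand(1, 1, 49)                      # 49 values -> 7x7 grid
#   grid = dummy.squeeze().reshape(7, 7)[None, None]  # (1, 1, 7, 7)
#   up = F.interpolate(grid, size=(224, 224), mode="bilinear", align_corners=False)
#   assert up.shape == (1, 1, 224, 224)
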
# --- 3. Main Prediction Function ---
def predict(image_upload: Image.Image, image_url: str):
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        raise gr.Error("Please upload an image or provide a URL to analyze.")

    # Normalize any mode (RGBA, grayscale, palette) to RGB before processing.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")
    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]
    # The AI-class label name varies between checkpoints (e.g., "ai" vs.
    # "artificial"), so match loosely instead of testing strict equality.
    if predicted_label.lower().startswith(("ai", "artificial")):
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention if these hotspots fall on unnatural-looking features such as hair, eyes, skin texture, or odd background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model weighted most heavily in this decision. "
            "These are likely well-formed, realistic features that AI generators often struggle to replicate."
        )
| print("Generating heatmap...") | |
| heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx) | |
| print("Heatmap generated.") | |
| labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))} | |
| return labels_dict, explanation, heatmap_image | |
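# Example of calling predict() directly, bypassing the UI (file name is hypothetical):
#   img = Image.open("sample.jpg")
#   labels, explanation, heatmap = predict(img, "")
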
# --- 4. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
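    # Note: both inputs are wired into predict(); when a file and a URL are both
    # provided, the uploaded file takes precedence (see the branch order in predict).
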
# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)