import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg  # used to rasterize the Captum figure for Gradio
# --- 1. Load Model and Processor ---
# Load the pre-trained model and the image processor from Hugging Face.
# We explicitly set torch_dtype to float32 to ensure CPU compatibility.
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()  # Set the model to evaluation mode
print("Model and processor loaded successfully.")
# --- 2. Define the Explainability (Grad-CAM) Function ---
# This function generates the heatmap showing which parts of the image the model focused on.

def classifier_forward(pixel_values):
    # Captum expects the forward function to return a plain tensor, so we unwrap
    # the Hugging Face ModelOutput and return only the logits.
    return model(pixel_values).logits

def generate_heatmap(image_tensor, original_image, target_class_index):
    # LayerGradCam requires a specific layer to hook into. This path assumes a ConvNeXT
    # backbone (as in the original comment); if the checkpoint uses a different
    # architecture, inspect it with print(model) and pick a suitable spatial layer.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv
    # Initialize LayerGradCam with the logits-returning forward function.
    lgc = LayerGradCam(classifier_forward, target_layer)
    # Generate attributions (the "importance" of each spatial location) for the target class.
    # relu_attributions keeps only the positive evidence for that class.
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
    # The raw attribution map has the low spatial resolution of the target layer,
    # so upsample it to the original image size (PIL's size is (W, H), hence the reversal).
    upsampled = torch.nn.functional.interpolate(
        attributions, size=original_image.size[::-1], mode="bilinear", align_corners=False
    )
    # Convert to the (H, W, 1) numpy layout expected by Captum's visualization helper.
    heatmap = upsampled.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    # Use Captum's visualization tool to overlay the heatmap on the original image.
    fig, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )
    # Rasterize the Matplotlib figure to an RGB array so gr.Image can display it.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    visualized_image = np.asarray(canvas.buffer_rgba())[..., :3]
    return visualized_image
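# Example of calling generate_heatmap outside the Gradio flow (a sketch; the file
# name "sample.jpg" is just a placeholder for a local test image):
#
#   img = Image.open("sample.jpg").convert("RGB")
#   pixel_values = processor(images=img, return_tensors="pt")["pixel_values"]
#   overlay = generate_heatmap(pixel_values, img, target_class_index=0)
#   Image.fromarray(overlay).save("heatmap_overlay.jpg")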
# --- 3. Define the Main Prediction Function ---
# This function will be called by Gradio every time a user uploads an image.
def predict(input_image: Image.Image):
    print(f"Received image of size: {input_image.size}")
    # Ensure the image is RGB (handles RGBA, grayscale, palette images, etc.)
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")
    # Preprocess the image for the model
    inputs = processor(images=input_image, return_tensors="pt")
    # Make a prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Get the predicted class index and the confidence score
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    # Get the label name (e.g., 'ai' or 'human')
    predicted_label = model.config.id2label[predicted_class_idx]
    # --- Generate a Human-Readable Explanation ---
    # This states in plain language which class the model chose and how confident it is.
    # Note: the exact label strings depend on the checkpoint's id2label mapping,
    # so the check below accepts a couple of common spellings.
    if predicted_label.lower() in ("ai", "artificial"):
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention to whether these hotspots fall on unnatural-looking features "
            "such as hair, eyes, skin texture, or odd background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )
    # --- Generate the Heatmap ---
    # We call our Grad-CAM function to create the visualization.
    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs["pixel_values"], input_image, predicted_class_idx)
    print("Heatmap generated.")
    # Return the classification labels, the text explanation, and the heatmap image.
    # The labels dictionary is for the gr.Label component.
    labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}
    return labels_dict, explanation, heatmap_image
# --- Create example files for the demo ---
# Download the example images before building the UI so that gr.Examples
# (which references these paths) finds them when the interface is constructed.
import os
from urllib.request import urlretrieve

print("Creating examples directory and downloading example images...")
os.makedirs("examples", exist_ok=True)
# These URLs are from the stable Hugging Face documentation assets
try:
    # AI Example 1: A classic AI-generated image (astronaut on a horse)
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/horse.png",
        "examples/ai_example_1.png"
    )
    # Human Example 1: A real photograph
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guide/zookeeper.png",
        "examples/human_example_1.jpg"
    )
    # AI Example 2: An AI-generated portrait, good for testing face/hair detection
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png",
        "examples/ai_example_2.png"
    )
    print("Example images downloaded successfully.")
except Exception as e:
    print(f"Failed to download example images: {e}")

# --- 4. Create the Gradio Interface ---
# This sets up the web UI with inputs and outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Upload an image to determine if it was generated by AI or created by a human.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
        In addition to the prediction, it provides a **heatmap** to show *why* the model made its decision, highlighting the areas it found most suspicious or authentic.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6)
            output_heatmap = gr.Image(label="Model Attention Heatmap")
    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap]
    )
    gr.Examples(
        examples=[
            ["examples/ai_example_1.png"],
            ["examples/human_example_1.jpg"],
            ["examples/ai_example_2.png"],
        ],
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=True,  # Pre-computes outputs so clicking an example is instant
    )
# --- 5. Launch the App ---
# Keep this as the last part of the script.
if __name__ == "__main__":
    demo.launch(debug=True)
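# A minimal requirements.txt for this Space would need to cover the imports used
# above (exact version pins are left to you; these are just the package names):
#
#   gradio
#   torch
#   transformers
#   captum
#   matplotlib
#   numpy
#   Pillow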