import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg  # renders the heatmap figure off-screen
import requests  # <-- Import requests
from io import BytesIO  # <-- Import BytesIO
# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")
# --- 2. Define the Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum needs a forward fn that returns plain logits, not a ModelOutput.
    def forward_fn(pixel_values):
        return model(pixel_values).logits

    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv
    lgc = LayerGradCam(forward_fn, target_layer)
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
    # Grad-CAM maps are coarse; upsample to the model's input resolution.
    upsampled = LayerAttribution.interpolate(attributions, image_tensor.shape[-2:])
    heatmap = np.transpose(upsampled.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
    # Resize the original image to the same resolution so the overlay lines up.
    resized_image = original_image.resize((image_tensor.shape[-1], image_tensor.shape[-2]))
    fig, _ = viz.visualize_image_attr(
        heatmap, np.array(resized_image), method="blended_heat_map", sign="positive",
        show_colorbar=True, title="Model Attention Heatmap", use_pyplot=False,
    )
    # Render off-screen and return an array that gr.Image can display.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())
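# A minimal local sanity check for the helper above (the "test.jpg" path is
# hypothetical, purely for illustration):
#   img = Image.open("test.jpg").convert("RGB")
#   tensor = processor(images=img, return_tensors="pt")["pixel_values"]
#   overlay = generate_heatmap(tensor, img, target_class_index=0)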
# --- 3. MODIFIED Main Prediction Function ---
# It now accepts two inputs: an uploaded image and a URL string.
def predict(image_upload: Image.Image, image_url: str):
    # --- Decide which input to use ---
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()  # Raise an exception for bad status codes
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        # If no input is provided, raise an error
        raise gr.Error("Please upload an image or provide a URL to analyze.")

    # Normalize any mode (RGBA, palette, grayscale) to RGB for the processor.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")
    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]
    if predicted_label.lower() == 'ai':
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Check whether these hotspots fall on unnatural-looking features such as hair, eyes, skin texture, or odd background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap highlights the regions that most strongly drove the 'human' prediction. "
            "These are typically well-formed, realistic features that AI models often struggle to replicate."
        )
| print("Generating heatmap...") | |
| heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx) | |
| print("Heatmap generated.") | |
| labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))} | |
| return labels_dict, explanation, heatmap_image | |
# --- 4. MODIFIED Gradio Interface ---
# We use gr.Tabs to create separate input sections.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            # --- TABS for different input methods ---
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    # The click event passes both possible inputs to the predict function.
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
# Examples are omitted for now: gr.Examples does not populate a tabbed input
# layout cleanly by default, and wiring that up would need a more complex setup.
# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)
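# Assumed dependencies for this Space (a sketch of requirements.txt, inferred
# from the imports above; matplotlib comes in via captum's visualization):
#   gradio, torch, transformers, captum, pillow, numpy, requests, matplotlib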