Spaces:
Running
Running
| """ | |
| Main application file for the Image Evaluator tool. | |
| This module integrates all components and provides a Gradio interface. | |
| """ | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import glob | |
| from PIL import Image | |
| import json | |
| import tempfile | |
| import shutil | |
| from datetime import datetime | |
| # Import custom modules | |
| from modules.metadata_extractor import MetadataExtractor | |
| from modules.technical_metrics import TechnicalMetrics | |
| from modules.aesthetic_metrics import AestheticMetrics | |
| from modules.aggregator import ResultsAggregator | |
| from modules.visualizer import Visualizer | |
class ImageEvaluator:
    """Main class for the Image Evaluator application.

    Owns one instance of each analysis component (metadata extraction,
    technical metrics, aesthetic metrics, aggregation, visualization) and
    exposes the pipeline steps the Gradio UI calls:
    process -> evaluate -> compare -> export / report.
    """

    # Metrics drawn on the radar chart, when present in the comparison table.
    RADAR_METRICS = ('aesthetic_score', 'sharpness', 'noise', 'contrast',
                     'color_harmony', 'prompt_similarity')
    # Metrics that each get their own bar chart, when present.
    BAR_CHART_METRICS = ('overall_score', 'aesthetic_score', 'prompt_similarity')

    def __init__(self):
        """Initialize the Image Evaluator.

        Creates ``<cwd>/results`` (if missing) and hands it to the Visualizer.
        """
        self.results_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(self.results_dir, exist_ok=True)

        # Initialize components
        self.metadata_extractor = MetadataExtractor()
        self.technical_metrics = TechnicalMetrics()
        self.aesthetic_metrics = AestheticMetrics()
        self.aggregator = ResultsAggregator()
        self.visualizer = Visualizer(self.results_dir)

        # Storage for results, keyed by image path. Both dicts accumulate
        # across calls; neither is cleared by this class.
        self.evaluation_results = {}
        self.metadata_cache = {}
        # Most recent comparison table; consumed by export_results().
        self.current_comparison = None

    @staticmethod
    def _report_progress(progress, done, total, verb):
        """Best-effort progress update: errors are printed, never raised.

        Args:
            progress: optional callable ``progress(fraction, message)`` (e.g. gr.Progress)
            done: number of items processed so far (1-based)
            total: total number of items
            verb: word used in the message, e.g. "Processing"
        """
        if progress is None:
            return
        try:
            progress(done / total, f"{verb} image {done}/{total}")
        except Exception as e:
            # Progress reporting is cosmetic; never let it abort the batch.
            print(f"Progress update error (non-critical): {e}")

    def process_images(self, image_files, progress=None):
        """
        Process a list of image files and extract metadata.

        Every extracted metadata value is also stored in ``self.metadata_cache``
        so evaluate_images() can look up each image's prompt.

        Args:
            image_files: list of image file paths
            progress: optional gradio Progress object
        Returns:
            tuple: (metadata_by_model, metadata_by_prompt) as produced by the
            metadata extractor's grouping helpers
        """
        metadata_list = []
        total_files = len(image_files)
        for i, img_path in enumerate(image_files, start=1):
            self._report_progress(progress, i, total_files, "Processing")
            metadata = self.metadata_extractor.extract_metadata(img_path)
            metadata_list.append((img_path, metadata))
            self.metadata_cache[img_path] = metadata

        # Group by model and prompt
        metadata_by_model = self.metadata_extractor.group_images_by_model(metadata_list)
        metadata_by_prompt = self.metadata_extractor.group_images_by_prompt(metadata_list)
        return metadata_by_model, metadata_by_prompt

    def evaluate_images(self, image_files, progress=None):
        """
        Evaluate a list of image files using all metrics.

        Results are returned and also merged into ``self.evaluation_results``.

        Args:
            image_files: list of image file paths
            progress: optional gradio Progress object
        Returns:
            dict: image path -> merged metric dict. On duplicate metric names the
            aesthetic value overrides the technical one.
        """
        results = {}
        total_files = len(image_files)
        for i, img_path in enumerate(image_files, start=1):
            self._report_progress(progress, i, total_files, "Evaluating")
            # Metadata exists only for paths seen by process_images(); a cached
            # None or a missing/None prompt is treated as "no prompt".
            metadata = self.metadata_cache.get(img_path) or {}
            prompt = metadata.get('prompt') or ''

            tech_metrics = self.technical_metrics.calculate_all_metrics(img_path)
            aesthetic_metrics = self.aesthetic_metrics.calculate_all_metrics(img_path, prompt)
            results[img_path] = {**tech_metrics, **aesthetic_metrics}

        self.evaluation_results.update(results)
        return results

    def compare_models(self, evaluation_results, metadata_by_model):
        """
        Compare different models based on evaluation results.

        Args:
            evaluation_results: dictionary with image paths as keys and metrics as values
            metadata_by_model: dictionary with model names as keys and lists of image paths as values
        Returns:
            tuple: (comparison_df, visualizations) where visualizations maps a
            display title to the path returned by the visualizer. It is empty
            when the comparison table is empty.
        """
        # Group results by model. Models none of whose images were evaluated
        # have nothing to aggregate, so they are left out of the comparison.
        results_by_model = {}
        for model, image_paths in metadata_by_model.items():
            model_results = [evaluation_results[img] for img in image_paths if img in evaluation_results]
            if model_results:
                results_by_model[model] = model_results

        comparison = self.aggregator.compare_models(results_by_model)
        comparison_df = self.aggregator.create_comparison_dataframe(comparison)
        self.current_comparison = comparison_df

        visualizations = {}
        if comparison_df.empty:
            # Nothing to plot; don't hand an empty table to the plotting helpers.
            return comparison_df, visualizations

        visualizations['Model Comparison Heatmap'] = self.visualizer.plot_heatmap(comparison_df)

        available_metrics = [m for m in self.RADAR_METRICS if m in comparison_df.columns]
        if available_metrics:
            visualizations['Model Comparison Radar Chart'] = self.visualizer.plot_radar_chart(
                comparison_df, available_metrics
            )

        for metric in self.BAR_CHART_METRICS:
            if metric in comparison_df.columns:
                visualizations[f'{metric} Comparison'] = self.visualizer.plot_metric_comparison(
                    comparison_df, metric
                )

        return comparison_df, visualizations

    def export_results(self, format='csv'):
        """
        Export current comparison results.

        Args:
            format: export format ('csv', 'excel', or 'html'); forwarded to the visualizer
        Returns:
            str: path to exported file, or None if no comparison has been run yet
        """
        if self.current_comparison is not None:
            return self.visualizer.export_comparison_table(self.current_comparison, format)
        return None

    def generate_report(self, comparison_df, visualizations):
        """
        Generate a comprehensive HTML report.

        Args:
            comparison_df: pandas DataFrame with comparison data
            visualizations: dictionary of visualization paths
        Returns:
            str: path to HTML report (as returned by the visualizer)
        """
        # Every column of the comparison table is listed as a reported metric.
        metrics_list = comparison_df.columns.tolist()
        return self.visualizer.generate_html_report(comparison_df, visualizations, metrics_list)
# Create Gradio interface
def create_interface():
    """Create and configure the Gradio interface.

    Builds a three-tab ``gr.Blocks`` app (Upload & Process, Compare Models,
    Export & Report) backed by a single ImageEvaluator instance.

    Returns:
        gr.Blocks: the configured interface (not launched).
    """
    # Initialize evaluator
    evaluator = ImageEvaluator()

    # State shared by the event handlers below.
    # NOTE(review): this dict and the evaluator live in the closure, so every
    # browser session of the running app shares them; concurrent users will
    # overwrite each other's uploads and results. Consider gr.State if
    # multi-user use matters.
    state = {
        'uploaded_images': [],
        'metadata_by_model': {},
        'metadata_by_prompt': {},
        'evaluation_results': {},
        'comparison_df': None,
        'visualizations': {},
        'report_path': None
    }

    def _to_path(f):
        """Return a filesystem path for one item delivered by the gr.File input.

        With ``type="filepath"`` Gradio passes plain path strings. Older
        releases passed tempfile-like objects exposing ``.name``. Both are
        accepted.
        """
        if isinstance(f, (str, os.PathLike)):
            return os.fspath(f)
        return f.name

    def _df_to_markdown(df):
        """Render *df* as Markdown, falling back to a fenced plain-text table.

        ``DataFrame.to_markdown`` needs the optional ``tabulate`` package and
        raises ImportError when it is not installed.
        """
        try:
            return df.to_markdown()
        except ImportError:
            return f"```\n{df.to_string()}\n```"

    def upload_images(files):
        """Handle image upload and processing; returns a summary string."""
        # Reset state so results from a previous upload are never mixed in.
        state['uploaded_images'] = []
        state['metadata_by_model'] = {}
        state['metadata_by_prompt'] = {}
        state['evaluation_results'] = {}
        state['comparison_df'] = None
        state['visualizations'] = {}
        state['report_path'] = None

        # The button can be clicked before anything is selected (files is None).
        if not files:
            return "No images uploaded. Please upload images first."

        image_paths = [_to_path(f) for f in files]
        state['uploaded_images'] = image_paths

        # Use a simple progress message instead of Gradio Progress object
        print("Extracting metadata...")
        metadata_by_model, metadata_by_prompt = evaluator.process_images(image_paths)
        state['metadata_by_model'] = metadata_by_model
        state['metadata_by_prompt'] = metadata_by_prompt

        model_summary = "\n".join(
            f"- {model}: {len(images)} images" for model, images in metadata_by_model.items()
        )
        return (
            f"Processed {len(image_paths)} images.\n\n"
            f"Found {len(metadata_by_model)} models:\n"
            f"{model_summary}\n\n"
            f"Found {len(metadata_by_prompt)} unique prompts."
        )

    def evaluate_images():
        """Evaluate all uploaded images; returns a status string."""
        if not state['uploaded_images']:
            return "No images uploaded. Please upload images first."
        # Use a simple progress message instead of Gradio Progress object
        print("Evaluating images...")
        evaluation_results = evaluator.evaluate_images(state['uploaded_images'])
        state['evaluation_results'] = evaluation_results
        return f"Evaluated {len(evaluation_results)} images with all metrics."

    def compare_models():
        """Compare models; returns (markdown table, heatmap path, radar path)."""
        if not state['evaluation_results'] or not state['metadata_by_model']:
            return "No evaluation results available. Please evaluate images first.", None, None

        comparison_df, visualizations = evaluator.compare_models(
            state['evaluation_results'], state['metadata_by_model']
        )
        state['comparison_df'] = comparison_df
        state['visualizations'] = visualizations
        state['report_path'] = evaluator.generate_report(comparison_df, visualizations)

        return (
            _df_to_markdown(comparison_df),
            visualizations.get('Model Comparison Heatmap'),
            visualizations.get('Model Comparison Radar Chart'),
        )

    def export_results(fmt):
        """Export results in the format chosen in the radio group; returns a status string."""
        if state['comparison_df'] is None:
            return "No comparison results available. Please compare models first."
        export_path = evaluator.export_results(fmt)
        if export_path:
            return f"Results exported to {export_path}"
        return "Failed to export results."

    def view_report():
        """Return the generated HTML report's markup for the gr.HTML output."""
        report_path = state['report_path']
        if not report_path or not os.path.exists(report_path):
            return "No report available. Please compare models first."
        # gr.HTML renders the string it receives, so pass the report contents
        # rather than its path.
        # NOTE(review): relative image links inside the report may not resolve
        # when embedded this way — confirm against generate_html_report output.
        with open(report_path, encoding="utf-8", errors="replace") as fh:
            return fh.read()

    # Create interface
    with gr.Blocks(title="Image Model Evaluator") as interface:
        gr.Markdown("# Image Model Evaluator")
        gr.Markdown("Upload images generated by different AI models to compare their quality and performance.")

        with gr.Tab("Upload & Process"):
            with gr.Row():
                with gr.Column():
                    upload_input = gr.File(
                        label="Upload Images (PNG format)",
                        file_count="multiple",
                        type="filepath"  # handlers receive path strings
                    )
                    upload_button = gr.Button("Process Uploaded Images")
                with gr.Column():
                    upload_output = gr.Textbox(
                        label="Processing Results",
                        lines=10,
                        interactive=False
                    )
            evaluate_button = gr.Button("Evaluate Images")
            evaluate_output = gr.Textbox(
                label="Evaluation Status",
                lines=2,
                interactive=False
            )

        with gr.Tab("Compare Models"):
            compare_button = gr.Button("Compare Models")
            with gr.Row():
                comparison_output = gr.Markdown(
                    label="Comparison Results"
                )
            with gr.Row():
                with gr.Column():
                    heatmap_output = gr.Image(
                        label="Model Comparison Heatmap",
                        interactive=False
                    )
                with gr.Column():
                    radar_output = gr.Image(
                        label="Model Comparison Radar Chart",
                        interactive=False
                    )

        with gr.Tab("Export & Report"):
            with gr.Row():
                with gr.Column():
                    export_format = gr.Radio(
                        label="Export Format",
                        choices=["csv", "excel", "html"],
                        value="csv"
                    )
                    export_button = gr.Button("Export Results")
                    export_output = gr.Textbox(
                        label="Export Status",
                        lines=2,
                        interactive=False
                    )
                with gr.Column():
                    report_button = gr.Button("View Full Report")
                    report_output = gr.HTML(
                        label="Full Report"
                    )

        # Set up event handlers
        upload_button.click(
            upload_images,
            inputs=[upload_input],
            outputs=[upload_output]
        )
        evaluate_button.click(
            evaluate_images,
            inputs=[],
            outputs=[evaluate_output]
        )
        compare_button.click(
            compare_models,
            inputs=[],
            outputs=[comparison_output, heatmap_output, radar_output]
        )
        export_button.click(
            export_results,
            inputs=[export_format],
            outputs=[export_output]
        )
        report_button.click(
            view_report,
            inputs=[],
            outputs=[report_output]
        )

    return interface
# Launch the application
if __name__ == "__main__":
    # share=True is deliberately not passed when running on HuggingFace Spaces.
    create_interface().launch()