| """ | |
| Module for visualizing image evaluation results and creating comparison tables. | |
| """ | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from matplotlib.colors import LinearSegmentedColormap | |
| import os | |
| import io | |
| from PIL import Image | |
| import base64 | |


class Visualizer:
    """Class for visualizing image evaluation results."""

    def __init__(self, output_dir='./results'):
        """
        Initialize the visualizer with an output directory.

        Args:
            output_dir: directory to save visualization results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Set up color schemes
        self.setup_colors()

    def setup_colors(self):
        """Set up color schemes for visualizations."""
        # Custom colormap for heatmaps
        self.cmap = LinearSegmentedColormap.from_list(
            'custom_cmap', ['#FF5E5B', '#FFED66', '#00CEFF', '#0089BA', '#008F7A'], N=256
        )
        # Color palette for bar charts
        self.palette = sns.color_palette("viridis", 10)
        # Set Seaborn style
        sns.set_style("whitegrid")

    def create_comparison_table(self, results_dict, metrics_list=None):
        """
        Create a comparison table from evaluation results.

        Args:
            results_dict: dictionary with model names as keys and evaluation results as values
            metrics_list: list of metrics to include in the table (if None, include all)

        Returns:
            pandas.DataFrame: comparison table
        """
        # Collect one row per model, then build the dataframe in a single pass
        # (concatenating inside the loop is quadratic in pandas).
        rows = []
        for model_name, model_results in results_dict.items():
            model_row = {'Model': model_name}
            for metric_name, metric_value in model_results.items():
                if metrics_list is None or metric_name in metrics_list:
                    # Round numeric values to 2 decimal places
                    if isinstance(metric_value, (int, float)):
                        model_row[metric_name] = round(metric_value, 2)
                    else:
                        model_row[metric_name] = metric_value
            rows.append(model_row)

        df = pd.DataFrame(rows)

        # Set Model as index
        if not df.empty:
            df.set_index('Model', inplace=True)

        return df
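
    # Hypothetical input/output sketch for create_comparison_table (names are
    # illustrative only):
    #   {"model_a": {"CLIPScore": 7.814, "FID": 4.1}, ...}
    # produces a DataFrame indexed by "Model" with one column per metric and
    # numeric values rounded to 2 decimal places.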

    def plot_metric_comparison(self, df, metric_name, title=None, figsize=(10, 6)):
        """
        Create a bar chart comparing models on a specific metric.

        Args:
            df: pandas DataFrame with comparison data
            metric_name: name of the metric to plot
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        if metric_name not in df.columns:
            raise ValueError(f"Metric '{metric_name}' not found in dataframe")

        # Create bar chart
        plt.figure(figsize=figsize)
        ax = sns.barplot(x=df.index, y=df[metric_name], palette=self.palette)

        # Set title and labels
        plt.title(title or f"Model Comparison: {metric_name}", fontsize=14)
        plt.xlabel("Model", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)

        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')

        # Add value labels on top of bars
        for i, v in enumerate(df[metric_name]):
            ax.text(i, v + 0.1, str(round(v, 2)), ha='center')

        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, f"{metric_name}_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)):
        """
        Create a radar chart comparing models across multiple metrics.

        Args:
            df: pandas DataFrame with comparison data
            metrics_list: list of metrics to include in the radar chart
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Keep only metrics that exist in the dataframe
        metrics = [m for m in metrics_list if m in df.columns]
        if not metrics:
            raise ValueError("None of the specified metrics found in dataframe")

        N = len(metrics)

        # Create polar axes
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, polar=True)

        # Compute the angle for each metric and close the loop
        angles = [n / float(N) * 2 * np.pi for n in range(N)]
        angles += angles[:1]

        # Plot each model as one closed polygon
        for i, model in enumerate(df.index):
            values = df.loc[model, metrics].values.flatten().tolist()
            values += values[:1]  # Close the loop
            color = self.palette[i % len(self.palette)]
            ax.plot(angles, values, linewidth=2, linestyle='solid', label=model, color=color)
            ax.fill(angles, values, alpha=0.1, color=color)

        # Set axis labels
        plt.xticks(angles[:-1], metrics, size=12)

        # Fixed radial limits; assumes metric values lie on a 0-10 scale
        ax.set_ylim(0, 10)

        # Add legend and title
        plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
        plt.title(title or "Model Comparison Across Metrics", size=16, y=1.1)

        # Save figure
        output_path = os.path.join(self.output_dir, "radar_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def plot_heatmap(self, df, title=None, figsize=(12, 8)):
        """
        Create a heatmap of all metrics across models.

        Args:
            df: pandas DataFrame with comparison data
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Create heatmap with annotated cells
        plt.figure(figsize=figsize)
        sns.heatmap(df, annot=True, cmap=self.cmap, fmt=".2f", linewidths=.5)

        # Set title
        plt.title(title or "Model Comparison Heatmap", fontsize=16)
        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, "comparison_heatmap.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def plot_prompt_performance(self, prompt_results, metric_name, top_n=5, figsize=(12, 8)):
        """
        Create a grouped bar chart showing model performance on different prompts.

        Args:
            prompt_results: dictionary with prompts as keys and model results as values
            metric_name: name of the metric to plot
            top_n: number of top prompts to include
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Flatten the nested results into long-form records
        data = []
        for prompt, models_data in prompt_results.items():
            for model, metrics in models_data.items():
                if metric_name in metrics:
                    data.append({
                        'Prompt': prompt,
                        'Model': model,
                        metric_name: metrics[metric_name],
                    })

        df = pd.DataFrame(data)
        if df.empty:
            raise ValueError(f"No data found for metric '{metric_name}'")

        # Keep only the top N prompts by average metric value
        top_prompts = df.groupby('Prompt')[metric_name].mean().nlargest(top_n).index.tolist()
        df_filtered = df[df['Prompt'].isin(top_prompts)]

        # Create grouped bar chart
        plt.figure(figsize=figsize)
        sns.barplot(x='Prompt', y=metric_name, hue='Model', data=df_filtered, palette=self.palette)

        # Set title and labels
        plt.title(f"Model Performance by Prompt: {metric_name}", fontsize=14)
        plt.xlabel("Prompt", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)

        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')

        # Move legend outside the plot area
        plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, f"prompt_performance_{metric_name}.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def create_image_grid(self, image_paths, titles=None, cols=3, figsize=(15, 15)):
        """
        Create a grid of images for visual comparison.

        Args:
            image_paths: list of paths to images
            titles: optional list of titles for each image
            cols: number of columns in the grid
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Calculate number of rows needed
        rows = (len(image_paths) + cols - 1) // cols

        # squeeze=False guarantees a 2-D array of axes even for a single
        # row or cell, so flatten() is always safe
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        # Add each image to the grid
        for i, img_path in enumerate(image_paths):
            if i < len(axes):
                try:
                    img = Image.open(img_path)
                    axes[i].imshow(np.array(img))
                    # Add title if provided
                    if titles and i < len(titles):
                        axes[i].set_title(titles[i])
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    axes[i].text(0.5, 0.5, "Error loading image", ha='center', va='center')
                # Remove axis ticks
                axes[i].set_xticks([])
                axes[i].set_yticks([])

        # Hide unused subplots
        for j in range(len(image_paths), len(axes)):
            axes[j].axis('off')

        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, "image_comparison_grid.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def export_comparison_table(self, df, format='csv'):
        """
        Export comparison table to file.

        Args:
            df: pandas DataFrame with comparison data
            format: export format ('csv', 'excel', or 'html')

        Returns:
            str: path to saved file
        """
        if format == 'csv':
            output_path = os.path.join(self.output_dir, "comparison_table.csv")
            df.to_csv(output_path)
        elif format == 'excel':
            # Requires an Excel writer backend such as openpyxl
            output_path = os.path.join(self.output_dir, "comparison_table.xlsx")
            df.to_excel(output_path)
        elif format == 'html':
            output_path = os.path.join(self.output_dir, "comparison_table.html")
            df.to_html(output_path)
        else:
            raise ValueError(f"Unsupported format: {format}")

        return output_path

    def generate_html_report(self, comparison_table, image_paths, metrics_list):
        """
        Generate a comprehensive HTML report with all visualizations.

        Args:
            comparison_table: pandas DataFrame with comparison data
            image_paths: dictionary of generated visualization image paths
            metrics_list: list of metrics included in the analysis

        Returns:
            str: path to saved HTML report
        """
        # The stylesheet contains literal braces and no interpolations,
        # so a plain (non-f) string avoids brace escaping here
        html_content = """
<!DOCTYPE html>
<html>
<head>
    <title>Image Model Evaluation Report</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            line-height: 1.6;
            margin: 0;
            padding: 20px;
            color: #333;
        }
        h1, h2, h3 {
            color: #2c3e50;
        }
        .container {
            max-width: 1200px;
            margin: 0 auto;
        }
        table {
            border-collapse: collapse;
            width: 100%;
            margin-bottom: 20px;
        }
        th, td {
            border: 1px solid #ddd;
            padding: 8px;
            text-align: left;
        }
        th {
            background-color: #f2f2f2;
        }
        tr:nth-child(even) {
            background-color: #f9f9f9;
        }
        .visualization {
            margin: 20px 0;
            text-align: center;
        }
        .visualization img {
            max-width: 100%;
            height: auto;
            box-shadow: 0 4px 8px rgba(0,0,0,0.1);
        }
        .metrics-list {
            background-color: #f8f9fa;
            padding: 15px;
            border-radius: 5px;
            margin-bottom: 20px;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Image Model Evaluation Report</h1>
        <h2>Metrics Overview</h2>
        <div class="metrics-list">
            <h3>Metrics included in this analysis:</h3>
            <ul>
"""

        # Add metrics list
        for metric in metrics_list:
            html_content += f"                <li><strong>{metric}</strong></li>\n"

        html_content += """
            </ul>
        </div>
        <h2>Comparison Table</h2>
"""

        # Add comparison table
        html_content += comparison_table.to_html(classes="table table-striped")

        # Add visualizations
        html_content += """
        <h2>Visualizations</h2>
"""
        for title, img_path in image_paths.items():
            if os.path.exists(img_path):
                # Convert image to base64 for embedding, so the report is self-contained
                with open(img_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode('utf-8')
                html_content += f"""
        <div class="visualization">
            <h3>{title}</h3>
            <img src="data:image/png;base64,{img_data}" alt="{title}">
        </div>
"""

        # Close HTML
        html_content += """
    </div>
</body>
</html>
"""

        # Save HTML report
        output_path = os.path.join(self.output_dir, "evaluation_report.html")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        return output_path
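

# Minimal usage sketch: wires the pieces together end to end. The model names
# and metric values below are hypothetical, chosen only to illustrate the flow.
if __name__ == "__main__":
    results = {
        "model_a": {"CLIPScore": 7.8, "Aesthetic": 6.2, "FID": 4.1},
        "model_b": {"CLIPScore": 8.3, "Aesthetic": 7.0, "FID": 5.6},
    }

    viz = Visualizer(output_dir="./results")
    table = viz.create_comparison_table(results)

    # Collect generated chart paths keyed by a display title for the report
    charts = {
        "CLIPScore Comparison": viz.plot_metric_comparison(table, "CLIPScore"),
        "Radar Comparison": viz.plot_radar_chart(table, ["CLIPScore", "Aesthetic", "FID"]),
        "Heatmap": viz.plot_heatmap(table),
    }

    viz.export_comparison_table(table, format="csv")
    report = viz.generate_html_report(table, charts, ["CLIPScore", "Aesthetic", "FID"])
    print(f"Report written to {report}")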