Spaces:
Running
Running
| """ | |
| Aesthetic metrics for image quality assessment using AI models. | |
| These metrics evaluate subjective aspects of images like aesthetic appeal, composition, etc. | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification, CLIPProcessor, CLIPModel | |
| import torchvision.transforms as transforms | |
class AestheticMetrics:
    """Compute aesthetic image-quality metrics for image files.

    Combines model-based scores (an image-classification model for aesthetic
    appeal, CLIP for prompt alignment) with hand-crafted heuristics for
    composition and color harmony. Each public ``calculate_*`` method returns
    a float in [0, 10] and falls back to a neutral 5.0 when the required model
    is not loaded or scoring fails (the error is printed, not raised).
    """

    def __init__(self):
        """Select the torch device and load the scoring models."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_models()

    def _initialize_models(self):
        """Load CLIP and the aesthetic classifier onto ``self.device``.

        Loading is best-effort: a failure is printed and recorded in
        ``self.clip_loaded`` / ``self.aesthetic_loaded`` so the scoring methods
        can return their neutral default instead of crashing.
        """
        # CLIP model for text-image similarity (via transformers).
        try:
            self.clip_model_name = "openai/clip-vit-base-patch32"
            self.clip_processor = CLIPProcessor.from_pretrained(self.clip_model_name)
            self.clip_model = CLIPModel.from_pretrained(self.clip_model_name)
            self.clip_model.to(self.device)
            self.clip_loaded = True
        except Exception as e:
            print(f"Warning: Could not load CLIP model: {e}")
            self.clip_loaded = False

        # Image-classification model whose class probabilities are turned into
        # an aesthetic score by _score_from_class_probs.
        try:
            self.aesthetic_model_name = "cafeai/cafe_aesthetic"
            self.aesthetic_extractor = AutoFeatureExtractor.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model = AutoModelForImageClassification.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model.to(self.device)
            self.aesthetic_loaded = True
        except Exception as e:
            print(f"Warning: Could not load aesthetic model: {e}")
            self.aesthetic_loaded = False

        # ImageNet-style preprocessing pipeline. It is not used by the scoring
        # methods of this class; presumably kept for external callers — verify.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path`` and return an RGB copy, closing the file handle."""
        with Image.open(image_path) as img:
            # convert() loads the pixel data, so the result outlives the file.
            return img.convert('RGB')

    def calculate_aesthetic_score(self, image_path):
        """
        Calculate aesthetic score using a pre-trained model.

        Args:
            image_path: path to the image file

        Returns:
            float: aesthetic score between 0 and 10 (5.0 if the model is not
            loaded or scoring fails)
        """
        if not self.aesthetic_loaded:
            return 5.0  # Default middle score if model not loaded
        try:
            image = self._load_rgb(image_path)
            inputs = self.aesthetic_extractor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.aesthetic_model(**inputs)
            # One image in the batch -> 1-D vector of class probabilities.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            id2label = getattr(self.aesthetic_model.config, "id2label", None)
            return self._score_from_class_probs(probs, id2label)
        except Exception as e:
            print(f"Error calculating aesthetic score: {e}")
            return 5.0

    @staticmethod
    def _score_from_class_probs(probs, id2label=None):
        """Map one image's class probabilities to a score in [0, 10].

        If ``id2label`` names a class exactly ``"aesthetic"`` (case-insensitive),
        the score is ``10 * P(that class)``. Otherwise the classes are treated as
        ordered bins spread evenly over 0-10 and the expected value is returned.

        Args:
            probs: tensor of class probabilities for a single image (flattened).
            id2label: optional mapping of class index -> label name, such as a
                transformers model config's ``id2label``.

        Returns:
            float: score between 0 and 10; 5.0 when there are fewer than two classes.
        """
        probs = probs.detach().flatten().float()
        if id2label:
            for idx, label in id2label.items():
                if str(label).strip().lower() == "aesthetic":
                    return 10.0 * probs[int(idx)].item()
        num_classes = probs.numel()
        if num_classes < 2:
            # A softmax over a single logit is always 1.0 and carries no signal.
            return 5.0
        # Size the weights from the real class count so any classifier head
        # broadcasts correctly (a fixed-length weight vector would not).
        weights = torch.linspace(0.0, 10.0, num_classes, device=probs.device)
        return torch.sum(probs * weights).item()

    def calculate_composition_score(self, image_path):
        """
        Estimate composition quality using rule of thirds and symmetry analysis.

        Args:
            image_path: path to the image file

        Returns:
            float: composition score between 0 and 10 (5.0 on error)
        """
        try:
            return self._composition_score_from_array(np.array(self._load_rgb(image_path)))
        except Exception as e:
            print(f"Error calculating composition score: {e}")
            return 5.0

    @staticmethod
    def _composition_score_from_array(img_array):
        """Score an (H, W, 3) RGB pixel array: 0.7 * thirds + 0.3 * symmetry, clamped to [0, 10]."""
        h, w = img_array.shape[:2]
        third_h, third_w = h // 3, w // 3
        # The four rule-of-thirds intersection points, as (x, y).
        thirds_points = [
            (third_w, third_h), (2 * third_w, third_h),
            (third_w, 2 * third_h), (2 * third_w, 2 * third_h),
        ]

        # Work in float: differencing uint8 values would wrap around
        # (e.g. 10 - 20 -> 246) and turn every darkening step into a huge edge.
        gray = img_array.astype(np.float64).mean(axis=2)
        # Replicating the last row/column gives a zero gradient at the far
        # border instead of a spurious edge against an implicit 0.
        edges = (np.abs(np.diff(gray, axis=0, append=gray[-1:, :]))
                 + np.abs(np.diff(gray, axis=1, append=gray[:, -1:])))

        # Mean edge strength in a +/-50 px window around each thirds point.
        thirds_score = 0.0
        for px, py in thirds_points:
            region = edges[max(0, py - 50):min(h, py + 50), max(0, px - 50):min(w, px + 50)]
            if region.size:
                thirds_score += float(np.mean(region))
        # Heuristic normalisation onto the 0-10 scale.
        thirds_score = min(10.0, thirds_score / 100)

        # Left-right mirror similarity: identical halves -> 10.
        flipped = np.fliplr(img_array)
        symmetry_diff = np.mean(np.abs(img_array.astype(float) - flipped.astype(float)))
        symmetry_score = 10 * (1 - symmetry_diff / 255)

        composition_score = 0.7 * thirds_score + 0.3 * symmetry_score
        return float(min(10, max(0, composition_score)))

    def calculate_color_harmony(self, image_path):
        """
        Calculate color harmony score based on color theory.

        Args:
            image_path: path to the image file

        Returns:
            float: color harmony score between 0 and 10 (5.0 on error)
        """
        try:
            hsv = np.array(self._load_rgb(image_path).convert('HSV'))
            return self._color_harmony_from_hsv(hsv)
        except Exception as e:
            print(f"Error calculating color harmony: {e}")
            return 5.0

    @staticmethod
    def _color_harmony_from_hsv(hsv):
        """Score an (H, W, 3) uint8 HSV array (PIL channel order) on a 0-10 scale.

        Combines low hue entropy, complementary-hue balance, analogous-hue
        adjacency and saturation variance. Returns 5.0 for an empty image.
        """
        hue = hsv[:, :, 0].ravel()
        if hue.size == 0:
            return 5.0
        # PIL hue spans 0-255, so range (0, 256) gives 36 equal bins; 18 bins
        # apart is then exactly half the hue circle (complementary).
        hist, _ = np.histogram(hue, bins=36, range=(0, 256))
        hist = hist / hist.sum()

        entropy = -np.sum(hist * np.log2(hist + 1e-10))
        # Overlap between each hue bin and its opposite bin.
        complementary_score = np.minimum(hist[:18], hist[18:]).sum()
        # Mass shared with the stronger of each bin's two circular neighbours.
        neighbours = np.maximum(np.roll(hist, -1), np.roll(hist, 1))
        analogous_score = np.minimum(hist, neighbours).sum()
        saturation_variance = np.var(hsv[:, :, 1])

        harmony_score = (
            3 * (1 - min(1, entropy / 5))                # Lower entropy is better for harmony
            + 3 * complementary_score                   # Complementary colors
            + 2 * analogous_score                       # Analogous colors
            + 2 * min(1, saturation_variance / 2000)    # Saturation variance
        )
        return float(min(10, max(0, harmony_score)))

    def calculate_prompt_similarity(self, image_path, prompt):
        """
        Calculate similarity between image and text prompt using CLIP.

        Args:
            image_path: path to the image file
            prompt: text prompt used to generate the image; truncated to the
                CLIP tokenizer's maximum length

        Returns:
            float: similarity score between 0 and 10 (5.0 if CLIP is not
            loaded, the prompt is empty, or scoring fails)
        """
        if not self.clip_loaded or not prompt:
            return 5.0  # Default middle score if model not loaded or no prompt
        try:
            image = self._load_rgb(image_path)
            # truncation=True keeps long generation prompts within the text
            # encoder's token limit instead of raising.
            inputs = self.clip_processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(self.device)
            with torch.no_grad():
                outputs = self.clip_model(**inputs)
            similarity = outputs.logits_per_image.item()
            # logits_per_image is CLIP's scaled image-text similarity; the /10
            # mapping presumes it lies roughly in 0-100 — confirm for the model.
            return min(10, max(0, similarity / 10))
        except Exception as e:
            print(f"Error calculating prompt similarity: {e}")
            return 5.0

    def calculate_all_metrics(self, image_path, prompt=None):
        """
        Calculate all aesthetic metrics for an image.

        Args:
            image_path: path to the image file
            prompt: optional text prompt used to generate the image

        Returns:
            dict: 'aesthetic_score', 'composition_score' and 'color_harmony',
            plus 'prompt_similarity' when a non-empty prompt is given
        """
        metrics = {
            'aesthetic_score': self.calculate_aesthetic_score(image_path),
            'composition_score': self.calculate_composition_score(image_path),
            'color_harmony': self.calculate_color_harmony(image_path),
        }
        if prompt:
            metrics['prompt_similarity'] = self.calculate_prompt_similarity(image_path, prompt)
        return metrics