Spaces:
Running
Running
| import os | |
| from typing import Dict | |
| import huggingface_hub | |
| import torch | |
| from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer | |
| from hpsv2.utils import hps_version_map, root_path | |
| from PIL import Image | |
class HPSMetric:
    """Human Preference Score (HPS v2) metric for a single image/prompt pair.

    On construction the ViT-H-14 OpenCLIP backbone is built, the HPS
    checkpoint selected by ``hps_version`` is downloaded from the
    ``xswu/HPSv2`` Hugging Face repo and loaded into it, and the matching
    tokenizer is created.
    """

    def __init__(self, hps_version: str = "v2.1"):
        """Select a compute device and load the HPS model.

        Args:
            hps_version: Key into ``hps_version_map`` choosing which HPS
                checkpoint to load. Defaults to ``"v2.1"``.

        Raises:
            ValueError: If ``hps_version`` is not a key of ``hps_version_map``.
        """
        # Fail fast, before the expensive model build and download.
        if hps_version not in hps_version_map:
            raise ValueError(
                f"Unknown HPS version {hps_version!r}; "
                f"expected one of {sorted(hps_version_map)}"
            )
        self.hps_version = hps_version
        self.device = self._select_device()
        # Holds "model" and "preprocess_val"; populated only once
        # initialisation has fully succeeded.
        self.model_dict = {}
        self._initialize_model()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the first available device among CUDA, Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _initialize_model(self) -> None:
        """Build the model, load the HPS checkpoint and create the tokenizer.

        Does nothing if ``self.model_dict`` is already populated.
        """
        if self.model_dict:
            return

        model, _, preprocess_val = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision="amp",
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(root_path, exist_ok=True)
        checkpoint_path = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", hps_version_map[self.hps_version]
        )
        # NOTE(review): torch.load unpickles the file; acceptable only because
        # it comes from the fixed xswu/HPSv2 repo, not from user input.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["state_dict"])
        model = model.to(self.device)
        model.eval()

        self.tokenizer = get_tokenizer("ViT-H-14")
        # Publish only after everything above succeeded, so a failed download
        # or load does not leave a half-initialised model behind the guard.
        self.model_dict["model"] = model
        self.model_dict["preprocess_val"] = preprocess_val

    def name(self) -> str:
        """Return the metric identifier, ``"hps"``."""
        return "hps"

    def compute_score(
        self,
        image: "Image.Image",
        prompt: str,
    ) -> Dict[str, float]:
        """Score how well ``image`` matches ``prompt`` under HPS.

        Args:
            image: PIL image to score; passed through ``preprocess_val``.
            prompt: Text prompt to compare the image against.

        Returns:
            ``{"hps": score}``, where ``score`` is the dot product of the
            model's image features and text features for this pair.
        """
        model = self.model_dict["model"]
        preprocess_val = self.model_dict["preprocess_val"]
        # Mixed precision is only used on CUDA. The deprecated
        # torch.cuda.amp.autocast() merely disabled itself (with a warning)
        # on MPS/CPU; enabled=False makes that a plain no-op here.
        use_amp = self.device.type == "cuda"

        with torch.no_grad():
            image_tensor = (
                preprocess_val(image)
                .unsqueeze(0)
                .to(device=self.device, non_blocking=True)
            )
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)

            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(image_tensor, text)
                image_features = outputs["image_features"]
                text_features = outputs["text_features"]
                logits_per_image = image_features @ text_features.T
                # One image and one prompt: the pair's score is element [0, 0].
                hps_score = torch.diagonal(logits_per_image)[0]

        return {"hps": float(hps_score)}