Spaces:
Running
Running
| import os | |
| import tempfile | |
| from typing import Dict | |
| import ImageReward as RM | |
| import torch | |
| from PIL import Image | |
class ImageRewardMetric:
    """Score how well an image matches a text prompt with ImageReward-v1.0.

    The ImageReward model is loaded once at construction time on the first
    available device in the order CUDA, Apple MPS, CPU.
    """

    def __init__(self):
        # Prefer CUDA, then Apple-silicon MPS, and finally fall back to CPU.
        self.device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        # The loader is given the device as its string form ("cuda"/"mps"/"cpu").
        self.model = RM.load("ImageReward-v1.0", device=str(self.device))

    def name(self) -> str:
        """Return the identifier of this metric ("image_reward")."""
        return "image_reward"

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Return the ImageReward score of *image* with respect to *prompt*.

        Args:
            image: The image to evaluate.
            prompt: The text prompt the image is judged against.

        Returns:
            ``{"image_reward": score}`` where ``score`` is a plain ``float``.
        """
        # The model is called with a file path, so the image is written to a
        # temporary PNG. The handle is closed immediately (on leaving the
        # ``with``) so the path can be reopened by name for writing/reading on
        # platforms that lock open files, e.g. Windows.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            image.save(tmp_path)
            score = self.model.score(prompt, [tmp_path])
        finally:
            # delete=False means nothing else cleans this up; remove the file
            # even when saving or scoring raises.
            os.unlink(tmp_path)
        # score() is called with a one-element list; depending on the
        # ImageReward version the result may be a scalar or a one-element
        # sequence (NOTE(review): confirm against the installed version).
        # Normalize to the float the signature promises.
        if isinstance(score, (list, tuple)):
            score = score[0]
        return {"image_reward": float(score)}