import ImageReward as RM
import torch

from rewards.base_reward import BaseRewardLoss


class ImageRewardLoss:
    """ImageReward-based loss for reward-guided optimization."""

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        memsave: bool = False,
    ):
        self.name = "ImageReward"
        self.weighting = weighting
        self.dtype = dtype
        # Pass the target device to RM.load so the model's internal `device`
        # attribute (used below to place inputs) matches where the weights live.
        self.imagereward_model = RM.load(
            "ImageReward-v1.0", device=device, download_root=cache_dir
        )
        self.imagereward_model = self.imagereward_model.to(
            device=device, dtype=self.dtype
        )
        self.imagereward_model.eval()
        BaseRewardLoss.freeze_parameters(self.imagereward_model.parameters())

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        imagereward_score = self.score_diff(prompt, image)
        return (2 - imagereward_score).mean()

    def score_diff(self, prompt, image):
        # Tokenize the prompt for the BLIP text encoder.
        text_input = self.imagereward_model.blip.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(self.imagereward_model.device)

        # Encode the image with the BLIP visual encoder.
        image_embeds = self.imagereward_model.blip.visual_encoder(image)

        # Text encoding with cross-attention over the image embeddings.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.imagereward_model.device
        )
        text_output = self.imagereward_model.blip.text_encoder(
            text_input.input_ids,
            attention_mask=text_input.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        # Score the [CLS] token features with the reward head (MLP) and
        # normalize with the model's precomputed reward statistics.
        txt_features = text_output.last_hidden_state[:, 0, :].to(
            self.imagereward_model.device, dtype=self.dtype
        )
        rewards = self.imagereward_model.mlp(txt_features)
        rewards = (rewards - self.imagereward_model.mean) / self.imagereward_model.std
        return rewards
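

# --- Usage sketch (illustrative only, not part of the original file) ---
# A minimal example of how this loss might be wired into a reward-guided
# optimization loop. The prompt, tensor shapes, cache directory, and
# hyperparameters below are assumptions for illustration; the image is
# expected as a preprocessed batch of shape (B, 3, 224, 224) matching the
# BLIP visual encoder's input resolution.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = ImageRewardLoss(
        weighting=1.0,        # assumed reward weight
        dtype=torch.float32,
        device=device,
        cache_dir="./cache",  # hypothetical download/cache directory
    )

    # Dummy image standing in for a decoded, resized, and normalized
    # generation from a text-to-image model.
    image = torch.randn(1, 3, 224, 224, device=device, requires_grad=True)
    prompt = "a photo of an astronaut riding a horse"

    loss = loss_fn(image, prompt) * loss_fn.weighting
    loss.backward()  # gradients flow back to the image (or its latents)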