Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModel | |
| from rewards.base_reward import BaseRewardLoss | |
class PickScoreLoss(BaseRewardLoss):
    """PickScore reward loss function for optimization.

    Loads the pretrained ``yuvalkirstain/PickScore_v1`` checkpoint and exposes
    its image/text feature extractors together with a scalar loss built from
    the scaled image-text similarity.
    """

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        tokenizer,
        memsave: bool = False,
    ):
        """Load the PickScore model and prepare it for use as a frozen scorer.

        Args:
            weighting: Forwarded to ``BaseRewardLoss.__init__`` together with
                the name ``"PickScore"``.
            dtype: Dtype the model is cast to.
            device: Device the model is moved to.
            cache_dir: Passed to ``AutoModel.from_pretrained`` as ``cache_dir``.
            tokenizer: Callable used by ``get_text_features`` to tokenize
                prompts; its result must support ``.to(device)`` and ``**``
                unpacking.
            memsave: When true, the model is passed through
                ``memsave_torch.nn.convert_to_memory_saving`` before being
                moved to ``device``.
        """
        self.tokenizer = tokenizer
        self.pickscore_model = AutoModel.from_pretrained(
            "yuvalkirstain/PickScore_v1", cache_dir=cache_dir
        ).eval()
        if memsave:
            # Imported lazily so memsave_torch is only needed when requested.
            import memsave_torch.nn

            self.pickscore_model = memsave_torch.nn.convert_to_memory_saving(
                self.pickscore_model
            )
        self.pickscore_model = self.pickscore_model.to(device, dtype=dtype)
        # freeze_parameters is not defined in this class (expected from the
        # base class); presumably it disables gradients on the weights.
        self.freeze_parameters(self.pickscore_model.parameters())
        super().__init__("PickScore", weighting)
        # NOTE(review): private Transformers hook whose signature has changed
        # between releases -- confirm it does what is intended on the pinned
        # transformers version.
        self.pickscore_model._set_gradient_checkpointing(True)

    def get_image_features(self, image) -> torch.Tensor:
        """Return the PickScore model's image features for ``image``.

        Args:
            image: Forwarded unchanged to the model's ``get_image_features``;
                presumably a batch of preprocessed pixel values -- verify
                against the caller.
        """
        return self.pickscore_model.get_image_features(image)

    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return the PickScore model's text features.

        The prompt is padded and truncated to at most 77 tokens. The token
        tensors are moved to the device the model's parameters live on, so
        the method works on whichever device was passed to ``__init__``
        instead of assuming ``"cuda"``.
        """
        # Follow the model rather than hard-coding "cuda": inputs and weights
        # must share a device for the forward pass to succeed.
        model_device = next(self.pickscore_model.parameters()).device
        prompt_token = self.tokenizer(
            prompt, return_tensors="pt", padding=True, max_length=77, truncation=True
        ).to(model_device)
        return self.pickscore_model.get_text_features(**prompt_token)

    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return ``30 - mean(exp(logit_scale) * image_features @ text_features.T)``.

        A higher scaled image-text similarity yields a lower loss.

        Args:
            image_features: Left operand of the similarity matrix product.
            text_features: Right operand; transposed before the product.
                Assumes both are 2-D with a matching last dimension -- TODO
                confirm whether callers normalize them first.
        """
        scaled_similarity = self.pickscore_model.logit_scale.exp() * (
            image_features @ text_features.T
        )
        # NOTE(review): the constant 30 is taken as-is from the original code;
        # its origin is not visible here.
        return 30 - scaled_similarity.mean()