| import torch | |
| from diffusers import ConfigMixin, Mel, ModelMixin | |
class ImageEncoder(ModelMixin, ConfigMixin):
    """Wrap an image processor and an encoder model to produce image embeddings.

    The module is switched to eval mode on construction. ``encode`` runs the
    processor on raw images, feeds the resulting ``pixel_values`` through the
    encoder, and returns the hidden state of the first token of the encoder's
    ``last_hidden_state``.
    """

    def __init__(self, image_processor, encoder_model):
        """Store the processor and encoder, then switch to eval mode.

        Args:
            image_processor: Callable accepting images plus
                ``return_tensors="pt"`` and returning a mapping that contains
                ``"pixel_values"`` (presumably a Hugging Face image processor /
                feature extractor -- confirm with callers).
            encoder_model: Model called with the pixel-values tensor; its
                output must expose ``last_hidden_state`` for ``encode``.
        """
        super().__init__()
        self.processor = image_processor
        self.encoder = encoder_model
        self.eval()

    def forward(self, x):
        """Run the wrapped encoder on ``x`` and return its raw output."""
        return self.encoder(x)

    @torch.no_grad()
    def encode(self, image):
        """Encode ``image`` into one embedding vector per input image.

        Runs without autograd: this is an inference helper, so building a
        gradient graph would only waste memory.

        Args:
            image: Whatever ``self.processor`` accepts (a single image or a
                batch of images).

        Returns:
            Tensor ``last_hidden_state[:, 0, :]``, i.e. the first token's
            hidden state for every item in the batch.
        """
        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"]

        # The processor produces tensors on the CPU in its own dtype. Move
        # them to the device of the model's weights (and match a floating
        # dtype such as fp16) so encoding still works after `.to("cuda")`
        # or `.half()`.
        param = next(self.parameters(), None)
        if param is not None:
            if param.is_floating_point():
                pixel_values = pixel_values.to(device=param.device, dtype=param.dtype)
            else:
                pixel_values = pixel_values.to(param.device)

        hidden = self(pixel_values).last_hidden_state
        # Indexing fixes last_hidden_state as rank-3, presumably
        # (batch, seq, hidden); token 0 is presumably the [CLS] token for
        # ViT-style encoders -- confirm against the encoder in use.
        embeddings = hidden[:, 0, :]
        return embeddings