from PIL import Image
import einops
import numpy as np
import torch
from hydra.utils import instantiate
from lightly.models import utils

# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer
from huggingface_hub import PyTorchModelHubMixin


class MAE(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(self, cfg):
        super().__init__()

        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)
        self.patch_size = vit.patch_embed.patch_size[0]

        # Get MAE backbone
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length
        self.encoder_dim = vit.embed_dim  # for convenience later

        # Get decoder
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )

        # saved as a model parameter, not an aug parameter, since masking is applied within the model
        self.mask_ratio = cfg.ssl_model.mask_ratio

        self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(self, batch, batch_idx):
        images = batch["image"]  # the batch contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)

        # decode and calculate loss (encoder no longer directly used)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for the missing class token
        # (the class token was added after calculating which indices to mask,
        # so we subtract 1 from idx_mask to get the patch indices that are masked)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, x_encoded

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]  # the batch contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for the missing class token (see training_step)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, None
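    # Note on token indexing (summarizing the comments in training_step and predict):
    # the backbone prepends a class token at sequence position 0, so sequence index i in
    # idx_keep / idx_mask corresponds to image patch i - 1 in the patchified image.
    # This is why pixel targets are gathered with `idx_mask - 1`.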
    def predict_step(self, batch, batch_idx):
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):
        # not used during training, only as a handy API
        # note the order of arguments is idx_mask first, as this is what most users change
        # idx 0 is the class token and is never masked;
        # the user must add 1 to all patch indices before passing idx_mask (assumed already done here)
        assert idx_mask is not None

        if idx_keep is None:
            # probably a user providing only idx_mask, not using predict_step above
            all_indices = set(range(self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                keep_row = list(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep).to(idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get masked and reconstructed images
        im_masked, im_reconstructed = self.mask_and_reconstruct_images(
            mask=idx_mask, num_images=batch_size, y=x_pred, x=images
        )

        # calculate MSE (as in training_step, but with per-image rather than per-batch reduction)
        patches = utils.patchify(images, self.patch_size)  # does not change batch dim
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions but batch

        return {
            "id_str": batch["id_str"],
            "images": image_batch_to_pil_list(images),
            "encoded": x_encoded,
            "masked": image_batch_to_pil_list(im_masked),
            "reconstructed": image_batch_to_pil_list(im_reconstructed),
            "reconstruction_error": mse_per_image,
        }

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
        im_masked = self.patchify(x)  # still the original image, just reshaped
        im_reconstructed = im_masked.clone()  # same for now, but will become the reconstructed images

        # if mask is None, both masked and reconstructed are just the original image, so do nothing;
        # otherwise fill in the masked patches
        if mask is not None:
            for batch_index in range(num_images):
                if batch_index >= x.shape[0]:
                    # we ran out of images in the batch
                    break

                # replace values with either 0 or the predicted fill values
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    im_masked[batch_index, token_idx - 1] = 0  # set masked pixels to 0
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # set masked pixels to predicted pixels

        # depatchify, i.e. reshape back to the original image shape
        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed

    def unpatchify(self, x):
        # i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=int(np.sqrt(x.shape[1])),
            w=int(np.sqrt(x.shape[1])),
        )
    def patchify(self, x):
        # note: "h" in the pattern is height // patch_size (i.e. the number of patches along the
        # height), and p is the patch size
        # in more usual terms:
        # x is an image of shape [b, c, h, w]
        # reshaped to [b, (h // p) * (w // p), p * p * c]
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=x.shape[-2] // self.patch_size,
            w=x.shape[-1] // self.patch_size,
        )

    @property
    def encoder(self):
        return self.backbone.vit  # hopefully equivalent to self.backbone.encode(x, idx_keep=all)


def image_batch_to_pil_list(images):
    images = einops.rearrange(images, "b c h w -> b h w c")
    images = torch.clamp(images, 0, 1) * 255
    images = images.cpu().numpy()
    images = images.astype(np.uint8)
    return [Image.fromarray(im) for im in images]
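
# Minimal sanity-check sketch: the patchify/unpatchify einops patterns used in MAE above
# should be exact inverses of each other. The patch size (16) and the random 64x64 input
# below are arbitrary example values, not taken from any config.
if __name__ == "__main__":
    patch_size = 16
    x = torch.rand(2, 3, 64, 64)  # [b, c, h, w]
    patches = einops.rearrange(
        x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size
    )  # [2, 16, 768]
    x_round_trip = einops.rearrange(
        patches,
        "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
        p1=patch_size,
        p2=patch_size,
        c=3,
        h=x.shape[-2] // patch_size,
        w=x.shape[-1] // patch_size,
    )
    assert torch.equal(x, x_round_trip), "patchify/unpatchify round trip should be lossless"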