euclid_masked_autoencoder / mae_timm_simplified.py
from PIL import Image
import einops
import numpy as np
import torch
from hydra.utils import instantiate
from lightly.models import utils
# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer
from huggingface_hub import PyTorchModelHubMixin


class MAE(torch.nn.Module, PyTorchModelHubMixin):
def __init__(self, cfg):
super().__init__()
vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)
self.patch_size = vit.patch_embed.patch_size[0]
# Get MAE backbone
self.backbone = MaskedVisionTransformerTIMM(vit=vit)
self.sequence_length = self.backbone.sequence_length
self.encoder_dim = vit.embed_dim # for convenience later
# Get decoder
self.decoder = MAEDecoderTIMM(
num_patches=vit.patch_embed.num_patches,
patch_size=self.patch_size,
embed_dim=vit.embed_dim,
decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
decoder_depth=cfg.ssl_model.decoder.depth,
decoder_num_heads=cfg.ssl_model.decoder.num_heads,
mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
proj_drop_rate=cfg.ssl_model.decoder.dropout,
attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
)
        self.mask_ratio = cfg.ssl_model.mask_ratio  # stored on the model rather than with the augmentations, since masking is applied inside the model
self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
# build decoder input
batch_size = x_encoded.shape[0]
x_decode = self.decoder.embed(x_encoded)
x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))
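        # x_masked is a full-length sequence of mask tokens, with the encoded (kept) tokens written back at their original positions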
# decoder forward pass
x_decoded = self.decoder.decode(x_masked)
# predict pixel values for masked tokens
x_pred = utils.get_at_index(x_decoded, idx_mask)
x_pred = self.decoder.predict(x_pred)
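        # x_pred should have shape [batch, num_masked_tokens, patch_size**2 * channels], matching the patchified pixel targets used for the loss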
return x_pred

    def training_step(self, batch, batch_idx):
        images = batch["image"]  # the batch contains only a single view of each image
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
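        # idx_keep / idx_mask index into the full token sequence; by default lightly's random_token_mask never masks position 0 (the class token)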
x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
# decode and calculate loss (encoder no longer directly used)
x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
# (class token was added after calculating which indices to mask,
# so we need to subtract 1 from idx_mask to get the new indices that are masked)
target = utils.get_at_index(patches, idx_mask - 1)
loss = self.criterion(x_pred, target)
return loss, x_encoded

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]  # the batch contains only a single view of each image
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
target = utils.get_at_index(patches, idx_mask - 1)
loss = self.criterion(x_pred, target)
return loss, None

    def predict_step(self, batch, batch_idx):
idx_keep, idx_mask = self.mask_random_indices(batch)
return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
idx_keep, idx_mask = utils.random_token_mask(
size=(batch["image"].shape[0], self.sequence_length), # (batch_size, seq_len)
mask_ratio=self.mask_ratio,
device=batch["image"].device,
)
return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):
        # not used during training; provided as a convenience API
        # note that idx_mask comes before idx_keep, since idx_mask is what most callers will want to set
        # token index 0 is the class token and is never masked
        # callers must add 1 to their patch indices before passing them in (to account for the class token); this method assumes that has already been done
assert idx_mask is not None
if idx_keep is None: # probably a user only providing idx_mask, not using predict_step above
all_indices = set(range(0, self.sequence_length))
idx_keep = []
for row in idx_mask:
keep_row = list(all_indices - set(row.tolist()))
idx_keep.append(keep_row)
idx_keep = torch.tensor(idx_keep).to(idx_mask.device)
images = batch["image"]
batch_size = images.shape[0]
x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
# get masked and reconstructed images
im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)
# calculate MSE (copied from above, but with per-image reduction not per-batch reduction)
patches = utils.patchify(images, self.patch_size) # does not change batch dim
target = utils.get_at_index(patches, idx_mask - 1)
mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1) # reduce all dimensions but batch
return {
'id_str': batch['id_str'],
'images': image_batch_to_pil_list(images),
'encoded': x_encoded,
'masked': image_batch_to_pil_list(im_masked),
'reconstructed': image_batch_to_pil_list(im_reconstructed),
'reconstruction_error': mse_per_image
}
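
    # A minimal usage sketch for predict() (hypothetical "model" and "batch" names, assuming a
    # batch dict with "image" and "id_str" keys as in predict_step above):
    #   patch_indices = torch.tensor([[5, 6, 7]])                  # patches to hide, one row per image
    #   idx_mask = (patch_indices + 1).to(batch["image"].device)   # +1 to skip the class token
    #   out = model.predict(batch, idx_mask=idx_mask)
    #   out["reconstructed"][0].save("reconstruction.png")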

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
im_masked = self.patchify(x) # still the original image, just reshaped
im_reconstructed = im_masked.clone() # same for now, but will become the reconstructed images
        # if mask is None, both masked and reconstructed are just the original image, so do nothing
        # otherwise, zero out the masked patches and fill in the predicted patches
if mask is not None:
for batch_index in range(num_images):
                # stop if the batch contains fewer images than num_images
if batch_index >= x.shape[0] or batch_index > num_images:
break
# replace values with either 0 or the predicted fill values
for mask_idx, token_idx in enumerate(mask[batch_index]):
im_masked[batch_index, token_idx - 1] = 0 # set masked pixels to 0
im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :] # set masked pixels to predicted pixels
# depatchify i.e. reshape back like original image
im_masked = self.unpatchify(im_masked)
im_reconstructed = self.unpatchify(im_reconstructed)
return im_masked, im_reconstructed

    def unpatchify(self, x):
# i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size
return einops.rearrange(
x,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
b=x.shape[0],
c=3,
h=int(np.sqrt(x.shape[1])),
w=int(np.sqrt(x.shape[1])),
)

    def patchify(self, x):
        # note: "h" and "w" in the pattern below are height // patch_size and width // patch_size, i.e. the number of patches per side, and p1/p2 are the patch size
        # in more normal terms:
        # x is an image of shape [b, c, h, w]
        # reshaped to [b, (h // patch_size) * (w // patch_size), patch_size^2 * c], i.e. [b, num_patches, pixels_per_patch]
return einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_size,
p2=self.patch_size,
b=x.shape[0],
c=3,
h=x.shape[-2] // self.patch_size,
w=x.shape[-1] // self.patch_size,
)
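
    # for example, with a 224x224 RGB image and patch_size 16:
    #   patchify:   [b, 3, 224, 224] -> [b, 196, 768]   (14*14 patches, 16*16*3 pixels each)
    #   unpatchify: [b, 196, 768]    -> [b, 3, 224, 224]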

    @property
def encoder(self):
return self.backbone.vit # hopefully equivalent to self.backbone.encode(x, idx_keep=all)


def image_batch_to_pil_list(images):
images = einops.rearrange(images, 'b c h w -> b h w c')
images = torch.clamp(images, 0, 1)*255
images = images.cpu().numpy()
images = images.astype(np.uint8)
# print(images.shape)
return [Image.fromarray(im) for im in images]
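

# A rough end-to-end sketch (assumes a Hydra config "cfg" providing the ssl_model / ssl_aug
# fields referenced in __init__, and batches shaped like {"image": tensor, "id_str": list}):
#   model = MAE(cfg)
#   loss, _ = model.training_step(batch, batch_idx=0)   # random masking + pixel MSE on masked patches
#   outputs = model.predict_step(batch, batch_idx=0)    # dict of PIL images, embeddings and per-image MSE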