from PIL import Image

import einops
import numpy as np
import torch
from hydra.utils import instantiate
from huggingface_hub import PyTorchModelHubMixin
from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer


class MAE(torch.nn.Module, PyTorchModelHubMixin):
    """Masked autoencoder (MAE) wrapping a timm ViT encoder and a lightly decoder."""

    def __init__(self, cfg):
        super().__init__()

        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)
        self.patch_size = vit.patch_embed.patch_size[0]

        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        # sequence_length counts the class token as well as the image patches
        self.sequence_length = self.backbone.sequence_length
        self.encoder_dim = vit.embed_dim

        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )
        self.mask_ratio = cfg.ssl_model.mask_ratio

        self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        # Encode only the kept (unmasked) tokens; idx_keep=None encodes everything.
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        batch_size = x_encoded.shape[0]

        # Project encoder tokens into the decoder dimension, then rebuild the full
        # sequence: mask tokens everywhere, with the encoded tokens written back
        # at their original (kept) positions.
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        x_decoded = self.decoder.decode(x_masked)

        # Predict pixel values only for the masked tokens.
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(self, batch, batch_idx):
        images = batch["image"]
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # Target: the original pixel patches at the masked positions. The -1 offset
        # converts token indices (which include the class token at position 0)
        # into patch indices.
        patches = utils.patchify(images, self.patch_size)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, x_encoded

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        patches = utils.patchify(images, self.patch_size)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, None

    def predict_step(self, batch, batch_idx):
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):
        assert idx_mask is not None

        # If the caller only supplies the masked token indices, keep everything else.
        if idx_keep is None:
            all_indices = set(range(self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                keep_row = sorted(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep, device=idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)

        # Per-image reconstruction error, averaged over the masked patches only.
        patches = utils.patchify(images, self.patch_size)
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)

        return {
            'id_str': batch['id_str'],
            'images': image_batch_to_pil_list(images),
            'encoded': x_encoded,
            'masked': image_batch_to_pil_list(im_masked),
            'reconstructed': image_batch_to_pil_list(im_reconstructed),
            'reconstruction_error': mse_per_image,
        }

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
        # Work in patch space: zero out the masked patches for the "masked" view,
        # and paste the predicted patches back in for the "reconstructed" view.
        im_masked = self.patchify(x)
        im_reconstructed = im_masked.clone()

        if mask is not None:
            for batch_index in range(min(num_images, x.shape[0])):
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    # Token indices include the class token, so patch index = token index - 1.
                    im_masked[batch_index, token_idx - 1] = 0
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]

        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed

    def unpatchify(self, x):
        # (B, num_patches, patch_size * patch_size * 3) -> (B, 3, H, W), assuming a square patch grid
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=int(np.sqrt(x.shape[1])),
            w=int(np.sqrt(x.shape[1])),
        )

    def patchify(self, x):
        # (B, 3, H, W) -> (B, num_patches, patch_size * patch_size * 3)
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=x.shape[-2] // self.patch_size,
            w=x.shape[-1] // self.patch_size,
        )

    @property
    def encoder(self):
        return self.backbone.vit
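

# Because MAE mixes in PyTorchModelHubMixin, a trained model can in principle be
# shared via the Hugging Face Hub. A hedged sketch (the repo id below is a
# placeholder, and from_pretrained only reconstructs the model if the init
# config round-trips through serialisation):
#
#   model.push_to_hub("your-username/mae-demo")
#   model = MAE.from_pretrained("your-username/mae-demo")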


def image_batch_to_pil_list(images):
    # Convert a (B, C, H, W) float tensor, assumed to be in [0, 1], to PIL images.
    images = einops.rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(images, 0, 1) * 255
    images = images.detach().cpu().numpy()
    images = images.astype(np.uint8)
    return [Image.fromarray(im) for im in images]
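

if __name__ == "__main__":
    # Minimal smoke test. The config schema matches the fields read in __init__,
    # but every value below (including the ViT _target_) is an illustrative
    # assumption, not taken from this repo's actual configs.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "ssl_model": {
                "vit": {
                    "_target_": "timm.models.vision_transformer.VisionTransformer",
                    "patch_size": 16,
                    "embed_dim": 192,
                    "depth": 4,
                    "num_heads": 3,
                },
                "decoder": {
                    "embed_dim": 128,
                    "depth": 2,
                    "num_heads": 4,
                    "mlp_ratio": 4.0,
                    "dropout": 0.0,
                    "attention_dropout": 0.0,
                },
                "mask_ratio": 0.75,
            },
            "ssl_aug": {"standard_view": {"output_size": 224}},
        }
    )

    model = MAE(cfg)
    batch = {"image": torch.rand(2, 3, 224, 224), "id_str": ["a", "b"]}
    loss, _ = model.training_step(batch, batch_idx=0)
    print(f"training loss: {loss.item():.4f}")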