initial commit
- app.py +154 -0
- mae_timm_simplified.py +224 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,154 @@
import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision.transforms import v2

import omegaconf
import torch

from lightly.models.utils import random_token_mask

from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

import mae_timm_simplified


# load model at global scope, so it is only downloaded and built once

cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)


def load_image_from_url(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img


def preprocess_image(image):
    # resize to the 224x224 input expected by the ViT and scale to [0, 1] floats
    preprocess = transforms.Compose([
        v2.ToImage(),
        transforms.Resize((224, 224)),
        v2.ToDtype(torch.float32, scale=True)
    ])
    return preprocess(image)


def predict(x, mask_ratio=0.9):

    # x is either a PIL image or a URL string
    if isinstance(x, str):
        image = load_image_from_url(x)
    else:
        image = x

    image = preprocess_image(image)

    batch = {
        'image': image.unsqueeze(0),  # (1, 3, H, W)
        'id_str': ['dummy'],  # predict() expects an id_str key; any placeholder will do
    }

    # randomly choose which tokens to mask (index 0, the class token, is never masked)
    _, idx_mask = random_token_mask(
        size=(1, mae.sequence_length),  # (batch_size, seq_len)
        mask_ratio=mask_ratio,
        device=batch['image'].device,
    )

    with torch.no_grad():
        result = mae.predict(batch, idx_mask=idx_mask)

    # original, masked, and reconstructed versions of the (single) image, as PIL images
    return result['images'][0], result['masked'][0], result['reconstructed'][0]


default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519261632289994548_gz_arcsinh_vis_y.jpg'
# default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520070700291475966_gz_arcsinh_vis_y.jpg'

sample_urls = [
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516202041290750678_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517952678291002259_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517613156287982372_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550504826285085169_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102020056/102020056_NEG573214944495746747_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102019124/102019124_NEG575999034503104669_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550052705286344755_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102158585/102158585_2699108800648101578_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102018213/102018213_NEG593362797514630636_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520447976287936395_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516572068288373807_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519801669290994144_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520376387292492930_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520535322288193494_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517834993290933807_gz_arcsinh_vis_y.jpg'
]


with gr.Blocks() as demo:

    gr.Markdown("""
    # Euclid Masked Autoencoder Demo

    Masked Autoencoders (MAEs) are self-supervised models trained to reconstruct missing parts of their input data. The MAE used here is trained on 3 million images of galaxies from Euclid RR2.

    Select an example image, upload your own, or provide an image URL, and see how well the MAE reconstructs the missing parts of the image.
    Adjust the masking ratio to see how it affects reconstruction quality. The MAE can often do well at 90% masking or more!

    The model is available on [Hugging Face](https://huggingface.co/mwalmsley/euclid-rr2-mae) along with code to apply it locally.

    For more details, see the workshop paper: Re-envisioning Euclid Galaxy Morphology: Identifying and Interpreting Features with Sparse Autoencoders, Wu & Walmsley, 2025, NeurIPS ML4Phys workshop.
    """)

    with gr.Row():
        image_output = gr.Image(type="pil", label="Original Image", height=224, width=224)
        masked_output = gr.Image(type="pil", label="Masked Image", height=224, width=224)
        reconstructed_output = gr.Image(type="pil", label="Reconstructed Image", height=224, width=224)
    outputs = [image_output, masked_output, reconstructed_output]

    with gr.Tab('Examples'):
        with gr.Row():
            # wraps onto extra rows if needed
            sample_images = [gr.Image(value=url, interactive=False, show_download_button=False, show_label=False, width=100, height=100) for url in sample_urls]

        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio", min_width=400, scale=0)
            explain = gr.Markdown("Click an image to update")

        for im in sample_images:
            im.select(predict, [im, mask_ratio_slider], outputs)
        # the slider alone does nothing on this tab: a reconstruction needs an image click, since we don't otherwise know which image to use

    with gr.Tab("Image"):

        image_input = gr.Image(type="pil", label="Input Image")

        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
            slider_go = gr.Button("Update")
        image_input.change(predict, [image_input, mask_ratio_slider], outputs)
        mask_ratio_slider.change(predict, [image_input, mask_ratio_slider], outputs)
        slider_go.click(predict, [image_input, mask_ratio_slider], outputs)

        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/) and on [Hugging Face](https://huggingface.co/mwalmsley/).')


    with gr.Tab("URL"):
        url_input = gr.Textbox(value=default_url, label="Image URL")
        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
            slider_go = gr.Button("Update")
        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/). You can get URLs by right-clicking each image.')
        url_input.change(predict, [url_input, mask_ratio_slider], outputs)
        mask_ratio_slider.change(predict, [url_input, mask_ratio_slider], outputs)
        slider_go.click(predict, [url_input, mask_ratio_slider], outputs)

    gr.Markdown(
        """
        ---

        *Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.*
        """)

demo.launch()
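For reference, a minimal sketch of applying the model locally rather than through the Gradio app, reusing the same calls as predict() above. It assumes mae_timm_simplified.py is importable and uses the same repo id and config filename as app.py; 'galaxy.jpg' and 'reconstructed.jpg' are hypothetical local paths.

# Minimal local-usage sketch, mirroring app.py.
# Assumptions: mae_timm_simplified.py is on the path; 'galaxy.jpg' and
# 'reconstructed.jpg' are hypothetical local file paths.
import omegaconf
import torch
from huggingface_hub import hf_hub_download
from lightly.models.utils import random_token_mask
from PIL import Image
from torchvision.transforms import v2

import mae_timm_simplified

cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)

preprocess = v2.Compose([
    v2.ToImage(),
    v2.Resize((224, 224)),
    v2.ToDtype(torch.float32, scale=True),
])
image = preprocess(Image.open("galaxy.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)

_, idx_mask = random_token_mask(size=(1, mae.sequence_length), mask_ratio=0.9, device=image.device)
with torch.no_grad():
    result = mae.predict({"image": image, "id_str": ["local"]}, idx_mask=idx_mask)

result["reconstructed"][0].save("reconstructed.jpg")  # predict() returns PIL images
print(result["reconstruction_error"])  # per-image MSE on the masked patches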
mae_timm_simplified.py
ADDED
@@ -0,0 +1,224 @@
from PIL import Image

import einops
import numpy as np
import torch
from hydra.utils import instantiate
from lightly.models import utils
# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer

from huggingface_hub import PyTorchModelHubMixin


class MAE(torch.nn.Module, PyTorchModelHubMixin):

    def __init__(self, cfg):
        super().__init__()

        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)

        self.patch_size = vit.patch_embed.patch_size[0]

        # Get MAE backbone
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length

        self.encoder_dim = vit.embed_dim  # for convenience later

        # Get decoder
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )
        self.mask_ratio = cfg.ssl_model.mask_ratio  # saved as a model attribute, not an augmentation, since masking is applied within the model

        self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(self, batch, batch_idx):
        images = batch["image"]  # the batch contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)

        # decode and calculate loss (encoder output no longer directly used)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for the missing class token:
        # idx_mask indexes the full token sequence, where index 0 is the class token,
        # but patchify produces patches only, so subtract 1 to index them
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)

        return loss, x_encoded

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]  # the batch contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for the missing class token (see training_step)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)

        return loss, None

    def predict_step(self, batch, batch_idx):
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):
        # not used during training, only as a handy API
        # note the order of arguments is idx_mask first, as this is what most people change!

        # idx 0 is the class token and is never masked
        # callers must add 1 to their patch indices before building idx_mask; this method assumes that has already been done

        assert idx_mask is not None

        if idx_keep is None:  # probably a user providing only idx_mask rather than using predict_step above
            # keep every token that is not masked (including the class token at index 0)
            all_indices = set(range(0, self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                keep_row = list(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep).to(idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get masked and reconstructed images
        im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)

        # calculate MSE (as in training_step, but with per-image rather than per-batch reduction)
        patches = utils.patchify(images, self.patch_size)  # does not change batch dim
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions but batch

        return {
            'id_str': batch['id_str'],
            'images': image_batch_to_pil_list(images),
            'encoded': x_encoded,
            'masked': image_batch_to_pil_list(im_masked),
            'reconstructed': image_batch_to_pil_list(im_reconstructed),
            'reconstruction_error': mse_per_image
        }

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
        im_masked = self.patchify(x)  # still the original image, just reshaped into patches
        im_reconstructed = im_masked.clone()  # same for now, but will become the reconstructed images

        # if mask is None, both masked and reconstructed are just the original image, so do nothing
        # otherwise, zero out the masked patches and fill in the predictions
        if mask is not None:
            for batch_index in range(num_images):
                # stop if we run out of images in the batch
                if batch_index >= x.shape[0]:
                    break
                # replace values with either 0 or the predicted fill values
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    im_masked[batch_index, token_idx - 1] = 0  # set masked patch pixels to 0
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # set masked patch pixels to predicted pixels

        # depatchify i.e. reshape back to the original image shape
        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed

    def unpatchify(self, x):
        # i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=int(np.sqrt(x.shape[1])),
            w=int(np.sqrt(x.shape[1])),
        )

    def patchify(self, x):
        # here "h" is height // patch_size, i.e. the number of patches per side, and p1/p2 are the patch size
        # in more familiar terms: x is an image of shape [b, c, h, w],
        # reshaped to [b, (h*w)/patch_size^2, patch_size^2*c]
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=x.shape[-2] // self.patch_size,
            w=x.shape[-1] // self.patch_size,
        )

    @property
    def encoder(self):
        return self.backbone.vit  # hopefully equivalent to self.backbone.encode(x, idx_keep=all)


def image_batch_to_pil_list(images):
    # (b, c, h, w) float tensor in [0, 1] -> list of PIL images
    images = einops.rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(images, 0, 1) * 255
    images = images.cpu().numpy()
    images = images.astype(np.uint8)
    return [Image.fromarray(im) for im in images]
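As the predict() comments note, token index 0 is the class token and patch indices must be shifted by +1 when building idx_mask by hand. A small sketch of masking chosen patches deterministically, assuming a model `mae` and a `batch` dict prepared as in app.py (the particular patches masked here are an arbitrary choice):

# Hand-built mask sketch: mask every other patch instead of a random selection.
# Assumes `mae` and `batch` (with 'image' and 'id_str' keys) as constructed in app.py.
import torch

num_patches = mae.sequence_length - 1  # token 0 is the class token; the rest are image patches
patch_indices = torch.arange(0, num_patches, 2)  # arbitrary choice: every other patch
idx_mask = (patch_indices + 1).unsqueeze(0)  # +1 skips the class token; add a batch dim -> (1, n_masked)

with torch.no_grad():
    result = mae.predict(batch, idx_mask=idx_mask)  # idx_keep is inferred as the complement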
requirements.txt
ADDED
@@ -0,0 +1,7 @@
timm
lightly
hydra-core
omegaconf
einops
pillow
torch