mwalmsley committed
Commit d88e92f · 1 parent: d6d433e

initial commit

Files changed (3):
  1. app.py +154 -0
  2. mae_timm_simplified.py +224 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,154 @@
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from torchvision.transforms import v2
+
+ import omegaconf
+ import torch
+
+ from lightly.models.utils import random_token_mask
+
+ from torchvision import transforms
+ from PIL import Image
+ import requests
+ from io import BytesIO
+
+ import mae_timm_simplified
+
+
+ # load the model once, at global scope
+
+ cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
+ cfg = omegaconf.OmegaConf.load(cfg_path)
+ mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)
+
+ def load_image_from_url(url):
+     response = requests.get(url)
+     img = Image.open(BytesIO(response.content)).convert("RGB")
+     return img
+
+ def preprocess_image(image):
+     preprocess = transforms.Compose([
+         v2.ToImage(),
+         transforms.Resize((224, 224)),
+         v2.ToDtype(torch.float32, scale=True)
+     ])
+     return preprocess(image)
+
+ def predict(x, mask_ratio=0.9):
+
+     # x is either a PIL image or a URL string
+     if isinstance(x, str):
+         image = load_image_from_url(x)
+     else:
+         image = x
+
+     image = preprocess_image(image)
+
+     batch = {
+         # 'image': torch.randn(1, 3, 224, 224),
+         'image': image.unsqueeze(0),  # (1, 3, H, W)
+         'id_str': ['dummy'],
+     }  # minimal batch; 'id_str' is just a placeholder
+
+     _, idx_mask = random_token_mask(
+         size=(1, mae.sequence_length),  # (batch_size, seq_len)
+         mask_ratio=mask_ratio,
+         device=batch['image'].device,
+     )
+
+     with torch.no_grad():
+         result = mae.predict(batch, idx_mask=idx_mask)
+
+     return result['images'][0], result['masked'][0], result['reconstructed'][0]
+
+ default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519261632289994548_gz_arcsinh_vis_y.jpg'
+ # default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520070700291475966_gz_arcsinh_vis_y.jpg'
+
+ sample_urls = [
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516202041290750678_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517952678291002259_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517613156287982372_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550504826285085169_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102020056/102020056_NEG573214944495746747_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102019124/102019124_NEG575999034503104669_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550052705286344755_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102158585/102158585_2699108800648101578_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102018213/102018213_NEG593362797514630636_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520447976287936395_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516572068288373807_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519801669290994144_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520376387292492930_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520535322288193494_gz_arcsinh_vis_y.jpg',
+     'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517834993290933807_gz_arcsinh_vis_y.jpg'
+ ]
+
+
+ with gr.Blocks() as demo:
+
+     gr.Markdown("""
+     # Euclid Masked Autoencoder Demo
+
+     Masked Autoencoders (MAEs) are self-supervised models trained to reconstruct missing parts of their input. The MAE used here was trained on 3 million galaxy images from Euclid RR2.
+
+     Select an example image, upload your own image, or provide an image URL, and see how well the MAE reconstructs the masked parts of the image.
+     Adjust the masking ratio to see how it affects the reconstruction quality. The MAE often does well at 90% masking or more!
+
+     The model is available on [Hugging Face](https://huggingface.co/mwalmsley/euclid-rr2-mae) along with code to apply it locally.
+
+     For more details, see the workshop paper: Re-envisioning Euclid Galaxy Morphology: Identifying and Interpreting Features with Sparse Autoencoders, Wu & Walmsley, 2025, NeurIPS ML4Phys workshop.
+
+     """)
+
+     with gr.Row():
+         image_output = gr.Image(type="pil", label="Original Image", height=224, width=224)
+         masked_output = gr.Image(type="pil", label="Masked Image", height=224, width=224)
+         reconstructed_output = gr.Image(type="pil", label="Reconstructed Image", height=224, width=224)
+     outputs = [image_output, masked_output, reconstructed_output]
+
+     with gr.Tab('Examples'):
+         with gr.Row():
+             # wraps onto extra rows if needed
+             sample_images = [gr.Image(value=url, interactive=False, show_download_button=False, show_label=False, width=100, height=100) for url in sample_urls]
+
+         with gr.Row():
+             mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio", min_width=400, scale=0)
+             explain = gr.Markdown("Click an image to update")
+
+         for im in sample_images:
+             im.select(predict, [im, mask_ratio_slider], outputs)
+         # on this tab, moving the mask ratio slider alone does nothing; an image click is required, since otherwise we don't know which image to use
+
+
+     with gr.Tab("Image"):
+
+         image_input = gr.Image(type="pil", label="Input Image")
+
+         with gr.Row():
+             mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
+             slider_go = gr.Button("Update")
+         image_input.change(predict, [image_input, mask_ratio_slider], outputs)
+         mask_ratio_slider.change(predict, [image_input, mask_ratio_slider], outputs)
+         slider_go.click(predict, [image_input, mask_ratio_slider], outputs)
+
+         gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo-style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/) and on [Hugging Face](https://huggingface.co/mwalmsley/).')
+
+
+     with gr.Tab("URL"):
+         url_input = gr.Textbox(value=default_url, label="Image URL")
+         with gr.Row():
+             mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
+             slider_go = gr.Button("Update")
+         gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo-style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/). You can get URLs by right-clicking each image.')
+         url_input.change(predict, [url_input, mask_ratio_slider], outputs)
+         mask_ratio_slider.change(predict, [url_input, mask_ratio_slider], outputs)
+         slider_go.click(predict, [url_input, mask_ratio_slider], outputs)
+
+     gr.Markdown(
+         """
+         ---
+
+         *Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.*
+         """)
+
+ demo.launch()
+
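Since app.py wires the model directly into Gradio event handlers, the bare inference path is easy to miss. Below is a minimal sketch of the same flow outside the UI, mirroring `predict()` above; the local file `my_galaxy.jpg` and the output filename are hypothetical placeholders, not files in this repo.

```python
# Sketch: reconstruct one local image without Gradio, mirroring app.py's predict().
# "my_galaxy.jpg" is a hypothetical placeholder path, not part of this repo.
import omegaconf
import torch
from huggingface_hub import hf_hub_download
from lightly.models.utils import random_token_mask
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2

import mae_timm_simplified

cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)
mae.eval()

preprocess = transforms.Compose([
    v2.ToImage(),
    transforms.Resize((224, 224)),
    v2.ToDtype(torch.float32, scale=True),
])
image = preprocess(Image.open("my_galaxy.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)

# mask 90% of the patch tokens, matching the demo's default slider value
_, idx_mask = random_token_mask(
    size=(1, mae.sequence_length), mask_ratio=0.9, device=image.device
)
with torch.no_grad():
    result = mae.predict({"image": image, "id_str": ["my_galaxy"]}, idx_mask=idx_mask)

result["reconstructed"][0].save("reconstruction.jpg")  # PIL images, as returned by predict()
print("per-image MSE on masked patches:", float(result["reconstruction_error"][0]))
```

This is the same `MAE.predict` call the Gradio handlers use; only the image loading and saving differ.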
mae_timm_simplified.py ADDED
@@ -0,0 +1,224 @@
+
+
+ from PIL import Image
+
+ import einops
+ import numpy as np
+ import torch
+ from hydra.utils import instantiate
+ from lightly.models import utils
+ # https://docs.lightly.ai/self-supervised-learning/examples/mae.html
+ from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
+ from timm.models.vision_transformer import VisionTransformer
+
+ from huggingface_hub import PyTorchModelHubMixin
+ class MAE(torch.nn.Module, PyTorchModelHubMixin):
+
+     def __init__(self, cfg):
+         super().__init__()
+
+         vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)
+
+         self.patch_size = vit.patch_embed.patch_size[0]
+
+         # MAE backbone: a masked ViT encoder
+         self.backbone = MaskedVisionTransformerTIMM(vit=vit)
+         self.sequence_length = self.backbone.sequence_length
+
+         self.encoder_dim = vit.embed_dim  # for convenience later
+
+         # lightweight decoder for reconstruction
+         self.decoder = MAEDecoderTIMM(
+             num_patches=vit.patch_embed.num_patches,
+             patch_size=self.patch_size,
+             embed_dim=vit.embed_dim,
+             decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
+             decoder_depth=cfg.ssl_model.decoder.depth,
+             decoder_num_heads=cfg.ssl_model.decoder.num_heads,
+             mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
+             proj_drop_rate=cfg.ssl_model.decoder.dropout,
+             attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
+         )
+         self.mask_ratio = cfg.ssl_model.mask_ratio  # saved as a model parameter, not an augmentation, since masking is applied within the model
+
+         self.criterion = torch.nn.MSELoss()
+
+     def forward_encoder(self, images, idx_keep=None):
+         return self.backbone.encode(images=images, idx_keep=idx_keep)
+
+     def forward_decoder(self, x_encoded, idx_keep, idx_mask):
+         # build decoder input
+         batch_size = x_encoded.shape[0]
+         x_decode = self.decoder.embed(x_encoded)
+         x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
+         x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))
+
+         # decoder forward pass
+         x_decoded = self.decoder.decode(x_masked)
+
+         # predict pixel values for masked tokens
+         x_pred = utils.get_at_index(x_decoded, idx_mask)
+         x_pred = self.decoder.predict(x_pred)
+         return x_pred
+
+     def training_step(self, batch, batch_idx):
+         images = batch["image"]  # the batch contains a single view of each image
+         batch_size = images.shape[0]
+         idx_keep, idx_mask = utils.random_token_mask(
+             size=(batch_size, self.sequence_length),
+             mask_ratio=self.mask_ratio,
+             device=images.device,
+         )
+         x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
+
+         # decode and calculate loss (the encoder output is no longer used directly)
+
+         x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
+
+         # get image patches for masked tokens
+         patches = utils.patchify(images, self.patch_size)
+         # must adjust idx_mask for the missing class token:
+         # token indices include the class token at position 0, but patches do not,
+         # so subtract 1 to map token indices onto patch indices
+         target = utils.get_at_index(patches, idx_mask - 1)
+
+         loss = self.criterion(x_pred, target)
+
+         return loss, x_encoded
+
+     def validation_step(self, batch, batch_idx, dataloader_idx=0):
+         images = batch["image"]  # the batch contains a single view of each image
+         batch_size = images.shape[0]
+         idx_keep, idx_mask = utils.random_token_mask(
+             size=(batch_size, self.sequence_length),
+             mask_ratio=self.mask_ratio,
+             device=images.device,
+         )
+         x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
+         x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
+
+         # get image patches for masked tokens
+         patches = utils.patchify(images, self.patch_size)
+         # must adjust idx_mask for the missing class token (see training_step)
+         target = utils.get_at_index(patches, idx_mask - 1)
+
+         loss = self.criterion(x_pred, target)
+
+         return loss, None
+
+     def predict_step(self, batch, batch_idx):
+         idx_keep, idx_mask = self.mask_random_indices(batch)
+         return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)
+
+     def mask_random_indices(self, batch):
+         idx_keep, idx_mask = utils.random_token_mask(
+             size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
+             mask_ratio=self.mask_ratio,
+             device=batch["image"].device,
+         )
+         return idx_keep, idx_mask
+
+     def predict(self, batch, idx_mask, idx_keep=None):
+         # not used during training; a convenient API for inference
+         # note the argument order is idx_mask first, as this is what most callers change
+
+         # index 0 is the class token and is never masked:
+         # callers must add 1 to all patch indices before passing them here; this method assumes that is already done
+
+         assert idx_mask is not None
+
+         if idx_keep is None:  # probably a caller providing only idx_mask, rather than using predict_step above
+             all_indices = set(range(0, self.sequence_length))
+             idx_keep = []
+             for row in idx_mask:
+                 keep_row = sorted(all_indices - set(row.tolist()))  # sorted for a deterministic token order
+                 idx_keep.append(keep_row)
+             idx_keep = torch.tensor(idx_keep).to(idx_mask.device)
+
+         images = batch["image"]
+         batch_size = images.shape[0]
+
+         x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
+         x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
+
+         # get masked and reconstructed images
+         im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)
+
+         # calculate MSE (as in training_step, but reduced per image rather than per batch)
+         patches = utils.patchify(images, self.patch_size)  # does not change the batch dim
+         target = utils.get_at_index(patches, idx_mask - 1)
+         mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
+         mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions except batch
+
+         return {
+             'id_str': batch['id_str'],
+             'images': image_batch_to_pil_list(images),
+             'encoded': x_encoded,
+             'masked': image_batch_to_pil_list(im_masked),
+             'reconstructed': image_batch_to_pil_list(im_reconstructed),
+             'reconstruction_error': mse_per_image
+         }
+
+
+     def mask_and_reconstruct_images(self, mask, num_images, y, x):
+         im_masked = self.patchify(x)  # still the original image, just reshaped into patches
+         im_reconstructed = im_masked.clone()  # identical for now, but will become the reconstructed images
+
+         # if mask is None, both masked and reconstructed are just the original image, so do nothing;
+         # otherwise:
+         if mask is not None:
+             for batch_index in range(num_images):
+                 # stop if we run out of images in the batch
+                 if batch_index >= x.shape[0]:
+                     break
+                 # replace patch values with either 0 or the predicted fill values
+                 for mask_idx, token_idx in enumerate(mask[batch_index]):
+                     im_masked[batch_index, token_idx - 1] = 0  # set masked patches to 0
+                     im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # set masked patches to the predicted values
+
+         # depatchify, i.e. reshape back to the original image shape
+         im_masked = self.unpatchify(im_masked)
+         im_reconstructed = self.unpatchify(im_reconstructed)
+         return im_masked, im_reconstructed
+
+     def unpatchify(self, x):
+         # i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is the patch size
+         return einops.rearrange(
+             x,
+             "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
+             p1=self.patch_size,
+             p2=self.patch_size,
+             b=x.shape[0],
+             c=3,
+             h=int(np.sqrt(x.shape[1])),
+             w=int(np.sqrt(x.shape[1])),
+         )
+
+     def patchify(self, x):
+         # here "h" is height // patch_size, i.e. the number of patches per side, and p is the patch size
+         # in other words:
+         # x is an image of shape [b, c, h, w]
+         # reshape to [b, (h // p) * (w // p), p * p * c], i.e. [b, num_patches, patch_dim]
+         return einops.rearrange(
+             x,
+             "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+             p1=self.patch_size,
+             p2=self.patch_size,
+             b=x.shape[0],
+             c=3,
+             h=x.shape[-2] // self.patch_size,
+             w=x.shape[-1] // self.patch_size,
+         )
+
+     @property
+     def encoder(self):
+         return self.backbone.vit  # hopefully equivalent to self.backbone.encode(x, idx_keep=all)
+
+
+ def image_batch_to_pil_list(images):
+     images = einops.rearrange(images, 'b c h w -> b h w c')
+     images = torch.clamp(images, 0, 1) * 255
+     images = images.cpu().numpy()
+     images = images.astype(np.uint8)
+     # print(images.shape)
+     return [Image.fromarray(im) for im in images]
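One detail worth calling out from `predict()` above: `idx_mask` holds token indices, and token 0 is the class token, so patch i in raster order corresponds to token i + 1. Below is a minimal sketch of masking a hand-picked set of patches (the first 16, chosen arbitrarily for illustration); it assumes a loaded `mae` and a preprocessed `(1, 3, 224, 224)` `image` tensor as set up in app.py.

```python
# Sketch: custom masking with hand-picked patches, rather than random_token_mask.
# Assumes `mae` and a preprocessed (1, 3, 224, 224) `image` tensor, as in app.py.
import torch

num_patches = mae.sequence_length - 1  # total tokens minus the class token
grid = int(num_patches ** 0.5)         # patches per side, e.g. 224 // patch_size
print(f"{num_patches} patches in a {grid}x{grid} grid, each {mae.patch_size}px square")

patch_indices = torch.arange(16)             # arbitrary example: first 16 patches in raster order
idx_mask = (patch_indices + 1).unsqueeze(0)  # +1 skips the class token; shape (1, n_masked)

with torch.no_grad():
    result = mae.predict({"image": image, "id_str": ["example"]}, idx_mask=idx_mask)
# predict() rebuilds idx_keep itself when only idx_mask is given, then returns PIL
# images under 'masked' and 'reconstructed', plus a per-image 'reconstruction_error'.
```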
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ timm
+ lightly
+ hydra-core
+ omegaconf
+ einops
+ pillow
+ torch