import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision.transforms import v2

import omegaconf
import torch

from lightly.models.utils import random_token_mask

from PIL import Image
import requests
from io import BytesIO

import mae_timm_simplified


# Load the pretrained MAE once at module level (global scope) so every Gradio callback reuses it.

cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)
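# assumption: MAE is a standard torch.nn.Module, so switch to eval mode for deterministic
# inference (from_pretrained may already do this; repeating it is harmless)
mae.eval()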

def load_image_from_url(url):
    """Download an image from a URL and return it as an RGB PIL image."""
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img

def preprocess_image(image):
    """Convert a PIL image or numpy array to a float32 CHW tensor resized to 224x224."""
    preprocess = v2.Compose([
        v2.ToImage(),
        v2.Resize((224, 224)),
        v2.ToDtype(torch.float32, scale=True),
    ])
    return preprocess(image)

def predict(x, mask_ratio=0.9):
    """Randomly mask image patches and reconstruct them with the MAE.

    x may be a PIL image, a numpy array (from a Gradio image component), or a URL string.
    Returns the original, masked, and reconstructed images for display.
    """
    if x is None:
        # nothing to do yet (e.g. the slider moved before an image was provided)
        return None, None, None

    if isinstance(x, str):
        image = load_image_from_url(x)
    else:
        image = x

    image = preprocess_image(image)

    batch = {
        'image': image.unsqueeze(0),  # (1, 3, H, W)
        'id_str': ['dummy'],  # placeholder id; only the image is used here
    }

    # choose which patch tokens to hide, uniformly at random
    _, idx_mask = random_token_mask(
        size=(1, mae.sequence_length),  # (batch_size, seq_len)
        mask_ratio=mask_ratio,
        device=batch['image'].device,
    )

    with torch.no_grad():
        result = mae.predict(batch, idx_mask=idx_mask)

    return result['images'][0], result['masked'][0], result['reconstructed'][0]
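# Minimal usage sketch outside Gradio (comments only, not executed here), assuming
# predict's outputs can be displayed directly, e.g. in a notebook:
#   original, masked, reconstructed = predict(default_url, mask_ratio=0.75)
# where default_url / sample_urls below point to example Euclid cutouts.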

default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519261632289994548_gz_arcsinh_vis_y.jpg'
# default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520070700291475966_gz_arcsinh_vis_y.jpg'

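# example Euclid cutouts with the same Galaxy Zoo-style preprocessing as the training data,
# shown as clickable thumbnails in the Examples tab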
sample_urls = [
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516202041290750678_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517952678291002259_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517613156287982372_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550504826285085169_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102020056/102020056_NEG573214944495746747_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102019124/102019124_NEG575999034503104669_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550052705286344755_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102158585/102158585_2699108800648101578_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102018213/102018213_NEG593362797514630636_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520447976287936395_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516572068288373807_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519801669290994144_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520376387292492930_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520535322288193494_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517834993290933807_gz_arcsinh_vis_y.jpg'
]


with gr.Blocks() as demo:

    gr.Markdown("""
    # Euclid Masked Autoencoder Demo

    Masked Autoencoders (MAEs) are self-supervised models trained to reconstruct missing parts of their input. The MAE used here was trained on 3 million galaxy images from Euclid RR2.

    Select a sample image, upload your own, or provide an image URL, and see how well the MAE reconstructs the masked parts of the image.
    Adjust the mask ratio to see how it affects reconstruction quality. The MAE often does well even at 90% masking or more!

    The model is available on [Hugging Face](https://huggingface.co/mwalmsley/euclid-rr2-mae), along with code to apply it locally.

    For more details, see the workshop paper: Re-envisioning Euclid Galaxy Morphology: Identifying and Interpreting Features with Sparse Autoencoders, Wu & Walmsley, 2025, NeurIPS ML4Phys workshop.

    ----
                
    """)

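    # three shared output panels: original image, masked input, and MAE reconstruction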
    with gr.Row():
        image_output = gr.Image(type="pil", label="Original Image", height=224, width=224)
        masked_output = gr.Image(type="pil", label="Masked Image", height=224, width=224)
        reconstructed_output = gr.Image(type="pil", label="Reconstructed Image", height=224, width=224)
        outputs = [image_output, masked_output, reconstructed_output]

    gr.Markdown('---')

    with gr.Tab('Examples'):

        explain = gr.Markdown("## Click an image to update")
        with gr.Column(scale=0):
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio", min_width=300, scale=0)

        with gr.Row():
            # wraps if needed
            sample_images = [gr.Image(value=url, interactive=False, show_download_button=False, show_label=False, width=100, height=100) for url in sample_urls]


        for im in sample_images:
            im.select(predict, [im, mask_ratio_slider], outputs)
        # the mask ratio slider alone does not trigger a prediction here: an image click is
        # needed so we know which sample image to reconstruct


    with gr.Tab("Image"):

        image_input = gr.Image(type="pil", label="Input Image")
            
        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
            slider_go = gr.Button("Update")
        image_input.change(predict, [image_input, mask_ratio_slider], outputs)
        mask_ratio_slider.change(predict, [image_input, mask_ratio_slider], outputs)
        slider_go.click(predict, [image_input, mask_ratio_slider], outputs)

        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/) and on [Hugging Face](https://huggingface.co/mwalmsley/).')


    with gr.Tab("URL"):
        url_input = gr.Textbox(value=default_url, label="Image URL")
        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
            slider_go = gr.Button("Update")
        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/). You can get URLs by right-clicking each image.')
        url_input.change(predict, [url_input, mask_ratio_slider], outputs)
        mask_ratio_slider.change(predict, [url_input, mask_ratio_slider], outputs)
        slider_go.click(predict, [url_input, mask_ratio_slider], outputs)
        
    gr.Markdown(
        """
        ---
                        
        *Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.*
        """)

demo.launch()