# mwalmsley's picture
# text
# e72e852
import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision.transforms import v2
import omegaconf
import torch
from lightly.models.utils import random_token_mask
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
import mae_timm_simplified
# Load the pretrained MAE once, at module scope, so that `predict` can reuse
# the same `mae` object on every call instead of reloading it.
MODEL_REPO_ID = "mwalmsley/euclid-rr2-mae"

# The config file is fetched from the same Hub repo as the model itself.
cfg_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained(MODEL_REPO_ID, cfg=cfg)
def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The downloaded image, converted to "RGB" mode.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
        PIL.UnidentifiedImageError: If the response body is not an image
            PIL can decode.
    """
    # NOTE(review): `url` comes straight from the demo's text box, so the
    # server will fetch whatever address a visitor types. Consider an
    # allow-list of hosts if that is a concern for the deployment.
    # Without a timeout, requests waits indefinitely on a stalled server and
    # the demo worker would hang with it.
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors (404, 403, ...) as such instead of letting PIL fail
    # later on an HTML error page with a misleading "cannot identify image".
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def preprocess_image(image):
    """Convert an input image into the tensor format fed to the model.

    The image goes through three steps, in order: conversion with
    `v2.ToImage`, resizing to 224x224, and conversion to float32 with value
    scaling enabled.

    Args:
        image: The image to convert (the caller passes a PIL image or
            whatever Gradio hands over - presumably anything `v2.ToImage`
            accepts; TODO confirm).

    Returns:
        The transformed image tensor.
    """
    steps = (
        v2.ToImage(),
        transforms.Resize((224, 224)),
        v2.ToDtype(torch.float32, scale=True),
    )
    # Apply each transform to the output of the previous one.
    transformed = image
    for step in steps:
        transformed = step(transformed)
    return transformed
def predict(x, mask_ratio=0.9):
    """Mask a fraction of an image's tokens and reconstruct them with the MAE.

    Args:
        x: The image to process. Either a URL string (downloaded with
            `load_image_from_url`) or an already-loaded image that
            `preprocess_image` accepts. `None` or a blank string means
            "no input yet".
        mask_ratio: Fraction of tokens to mask, passed to
            `random_token_mask`.

    Returns:
        A 3-tuple `(original, masked, reconstructed)` taken from the first
        entry of the `'images'`, `'masked'` and `'reconstructed'` results of
        `mae.predict`. When there is no input, `(None, None, None)`.
    """
    if isinstance(x, str):
        # Pasted URLs often carry stray whitespace or a trailing newline.
        x = x.strip()

    # Gradio fires the change handlers with no value when the image is
    # cleared, the URL box is emptied, or the slider moves before anything
    # was uploaded. Return empty outputs instead of raising.
    # (An explicit isinstance check avoids truth-testing array inputs.)
    if x is None or (isinstance(x, str) and not x):
        return None, None, None

    image = load_image_from_url(x) if isinstance(x, str) else x
    image = preprocess_image(image)

    batch = {
        'image': image.unsqueeze(0),  # add a leading batch dimension of 1
        'id_str': ['dummy'],  # placeholder id - presumably required by mae.predict
    }

    # Only the second value returned by random_token_mask (the indices of the
    # masked tokens) is needed here.
    _, idx_mask = random_token_mask(
        size=(1, mae.sequence_length),  # (batch_size, seq_len)
        mask_ratio=mask_ratio,
        device=batch['image'].device,
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        result = mae.predict(batch, idx_mask=idx_mask)

    return result['images'][0], result['masked'][0], result['reconstructed'][0]
# All demo cutouts sit under one bucket prefix and share the naming pattern
# "<folder>/<folder>_<cutout>_gz_arcsinh_vis_y.jpg", so the URLs are built
# from (folder, cutout) id pairs rather than spelled out in full.
_CUTOUT_ROOT = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y'


def _cutout_url(folder_id, cutout_id):
    """Return the full URL of the cutout `cutout_id` stored under `folder_id`."""
    return f'{_CUTOUT_ROOT}/{folder_id}/{folder_id}_{cutout_id}_gz_arcsinh_vis_y.jpg'


# URL pre-filled in the "URL" tab.
default_url = _cutout_url('102042913', 'NEG519261632289994548')
# Alternative default kept for reference: folder 102042913, cutout NEG520070700291475966.

# Images shown as clickable thumbnails in the "Examples" tab.
sample_urls = [
    _cutout_url(folder_id, cutout_id)
    for folder_id, cutout_id in (
        ('102042913', 'NEG516202041290750678'),
        ('102042913', 'NEG517952678291002259'),
        ('102042913', 'NEG517613156287982372'),
        ('102043552', 'NEG550504826285085169'),
        ('102020056', 'NEG573214944495746747'),
        ('102019124', 'NEG575999034503104669'),
        ('102043552', 'NEG550052705286344755'),
        ('102158585', '2699108800648101578'),
        ('102018213', 'NEG593362797514630636'),
        ('102042913', 'NEG520447976287936395'),
        ('102042913', 'NEG516572068288373807'),
        ('102042913', 'NEG519801669290994144'),
        ('102042913', 'NEG520376387292492930'),
        ('102042913', 'NEG520535322288193494'),
        ('102042913', 'NEG517834993290933807'),
    )
]
def _add_mask_controls(source, outputs):
    """Add a mask-ratio slider plus "Update" button and wire them to `predict`.

    Must be called inside the Gradio layout context (e.g. a tab) that should
    contain the controls. A change of `source`, a change of the slider, or a
    click on the button each call `predict(source, slider)` and write the
    result to `outputs`.

    Args:
        source: The input component (image or textbox) whose value is passed
            to `predict` as its first argument.
        outputs: The components that receive `predict`'s three return values.

    Returns:
        The `(slider, button)` components that were created.
    """
    with gr.Row():
        slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
        button = gr.Button("Update")
    for register in (source.change, slider.change, button.click):
        register(predict, [source, slider], outputs)
    return slider, button


with gr.Blocks() as demo:
    # Paragraphs are separated by blank lines so that Markdown renders them
    # individually and the closing "----" is a horizontal rule rather than a
    # setext heading underline for the text above it.
    gr.Markdown("""
# Euclid Masked Autoencoder Demo

Masked Autoencoders (MAEs) are self-supervised learning models trained to reconstruct missing parts of input data. The MAE used here is trained on 3 million images of galaxies from Euclid RR2.

Select an image, upload your own image, or provide an image URL, and see how well the MAE can reconstruct the missing parts of the image.

Adjust the masking ratio to see how it affects the reconstruction quality. The MAE can often do well at 90% masking or more!

The model is available on [Hugging Face](https://huggingface.co/mwalmsley/euclid-rr2-mae) along with code to apply it locally.

For more details, see the workshop paper: Re-envisioning Euclid Galaxy Morphology: Identifying and Interpreting Features with Sparse Autoencoders, Wu & Walmsley, 2025, NeurIPS ML4Phys workshop.

----
""")

    # Result panel shared by all three input tabs.
    with gr.Row():
        image_output = gr.Image(type="pil", label="Original Image", height=224, width=224)
        masked_output = gr.Image(type="pil", label="Masked Image", height=224, width=224)
        reconstructed_output = gr.Image(type="pil", label="Reconstructed Image", height=224, width=224)
    outputs = [image_output, masked_output, reconstructed_output]

    gr.Markdown('---')

    with gr.Tab('Examples'):
        gr.Markdown("## Click an image to update")
        with gr.Column(scale=0):
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio", min_width=300, scale=0)
        with gr.Row():
            # The row wraps onto further lines if the thumbnails do not fit.
            sample_images = [
                gr.Image(value=url, interactive=False, show_download_button=False, show_label=False, width=100, height=100)
                for url in sample_urls
            ]
        # Moving the slider alone does nothing in this tab: only a click on a
        # thumbnail tells us which image to run, and it uses the slider's
        # current value at that moment.
        for im in sample_images:
            im.select(predict, [im, mask_ratio_slider], outputs)

    with gr.Tab("Image"):
        image_input = gr.Image(type="pil", label="Input Image")
        mask_ratio_slider, slider_go = _add_mask_controls(image_input, outputs)
        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/) and on [Hugging Face](https://huggingface.co/mwalmsley/).')

    with gr.Tab("URL"):
        url_input = gr.Textbox(value=default_url, label="Image URL")
        mask_ratio_slider, slider_go = _add_mask_controls(url_input, outputs)
        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/). You can get URLs by right-clicking each image.')

    gr.Markdown(
        """
---
*Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle, for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.*
""")

# Start the web server for the demo.
demo.launch()