import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision.transforms import v2
import omegaconf
import torch
from lightly.models.utils import random_token_mask
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

import mae_timm_simplified

# load model, global scope
cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)


def load_image_from_url(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img


def preprocess_image(image):
    # convert to a float tensor in [0, 1] and resize to the model's input size
    preprocess = transforms.Compose([
        v2.ToImage(),
        transforms.Resize((224, 224)),
        v2.ToDtype(torch.float32, scale=True)
    ])
    return preprocess(image)


def predict(x, mask_ratio=0.9):
    # x is either a PIL image or a URL string
    if isinstance(x, str):
        image = load_image_from_url(x)
    else:
        image = x
    image = preprocess_image(image)

    batch = {
        # 'image': torch.randn(1, 3, 224, 224),
        'image': image.unsqueeze(0),  # (1, 3, H, W)
        'id_str': ['dummy'],
    }  # dummy input

    # randomly choose which patch tokens to mask
    _, idx_mask = random_token_mask(
        size=(1, mae.sequence_length),  # (batch_size, seq_len)
        mask_ratio=mask_ratio,
        device=batch['image'].device,
    )
    with torch.no_grad():
        result = mae.predict(batch, idx_mask=idx_mask)
    return result['images'][0], result['masked'][0], result['reconstructed'][0]


default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519261632289994548_gz_arcsinh_vis_y.jpg'
# default_url = 'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520070700291475966_gz_arcsinh_vis_y.jpg'
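# Minimal sketch of calling the model outside the Gradio UI (not part of the demo;
# assumes the helpers above and that `default_url` is still reachable). Uncomment to
# run a one-off reconstruction at startup:
#
# original, masked, reconstructed = predict(default_url, mask_ratio=0.75)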
sample_urls = [
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516202041290750678_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517952678291002259_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517613156287982372_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550504826285085169_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102020056/102020056_NEG573214944495746747_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102019124/102019124_NEG575999034503104669_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102043552/102043552_NEG550052705286344755_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102158585/102158585_2699108800648101578_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102018213/102018213_NEG593362797514630636_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520447976287936395_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG516572068288373807_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG519801669290994144_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520376387292492930_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG520535322288193494_gz_arcsinh_vis_y.jpg',
    'https://storage.googleapis.com/zootasks_test_us/euclid/q1_v5/cutouts_jpg_gz_arcsinh_vis_y/102042913/102042913_NEG517834993290933807_gz_arcsinh_vis_y.jpg'
]

with gr.Blocks() as demo:

    gr.Markdown("""
    # Euclid Masked Autoencoder Demo

    Masked Autoencoders (MAEs) are self-supervised models trained to reconstruct missing parts of their input. The MAE used here was trained on 3 million galaxy images from Euclid RR2.

    Select an example image, upload your own, or provide an image URL, and see how well the MAE reconstructs the missing parts of the image. Adjust the masking ratio to see how it affects reconstruction quality. The MAE can often do well at 90% masking or more!

    The model is available on [Hugging Face](https://huggingface.co/mwalmsley/euclid-rr2-mae) along with code to apply it locally. For more details, see the workshop paper: Re-envisioning Euclid Galaxy Morphology: Identifying and Interpreting Features with Sparse Autoencoders, Wu & Walmsley, 2025, NeurIPS ML4Phys workshop.
    """)

    with gr.Row():
        image_output = gr.Image(type="pil", label="Original Image", height=224, width=224)
        masked_output = gr.Image(type="pil", label="Masked Image", height=224, width=224)
        reconstructed_output = gr.Image(type="pil", label="Reconstructed Image", height=224, width=224)
    outputs = [image_output, masked_output, reconstructed_output]

    with gr.Tab('Examples'):
        with gr.Row():  # wraps if needed
            sample_images = [
                gr.Image(value=url, interactive=False, show_download_button=False, show_label=False, width=100, height=100)
                for url in sample_urls
            ]
        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio", min_width=400, scale=0)
            explain = gr.Markdown("Click an image to update the reconstruction")
        for im in sample_images:
            # the mask ratio slider alone doesn't trigger an update here; an image click
            # is required, since otherwise we don't know which example image to use
            im.select(predict, [im, mask_ratio_slider], outputs)

    with gr.Tab("Image"):
        image_input = gr.Image(type="pil", label="Input Image")
        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
            slider_go = gr.Button("Update")
        image_input.change(predict, [image_input, mask_ratio_slider], outputs)
        mask_ratio_slider.change(predict, [image_input, mask_ratio_slider], outputs)
        slider_go.click(predict, [image_input, mask_ratio_slider], outputs)
        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/) and on [Hugging Face](https://huggingface.co/mwalmsley/).')

    with gr.Tab("URL"):
        url_input = gr.Textbox(value=default_url, label="Image URL")
        with gr.Row():
            mask_ratio_slider = gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Mask Ratio")
            slider_go = gr.Button("Update")
        gr.Markdown('The model was trained to work well on images from Euclid, with Galaxy Zoo style preprocessing. These images are shared at [euclid.streamlit.app](https://euclid.streamlit.app/). You can get URLs by right-clicking each image.')
        url_input.change(predict, [url_input, mask_ratio_slider], outputs)
        mask_ratio_slider.change(predict, [url_input, mask_ratio_slider], outputs)
        slider_go.click(predict, [url_input, mask_ratio_slider], outputs)

    gr.Markdown(
        """
        ---
        *Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.*
        """)

demo.launch()
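# Local usage note (not from the original app; the filename is an assumption): with the
# dependencies above installed, running this file, e.g. `python app.py`, starts the Gradio
# server and prints a local URL. Passing share=True to demo.launch() would additionally
# create a temporary public link.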