## Unpacking SDXL Turbo
### Interpreting Text-to-Image Models with Sparse Autoencoders

This Colab notebook prepares and launches a demo Gradio app for discovering the features learned by the SAEs.

### Cloning code and installing dependencies

In [1]:
%load_ext gradio
%load_ext autoreload
%autoreload 2

### Loading SDXL Turbo, the SAEs, and the Feature Retriever

In [6]:
import os, sys

sys.path.insert(0, os.getcwd())
import gradio as gr
import torch
from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder
from app import create_demo
assert torch.cuda.is_available(), "Your machine has no access to GPU. If you are using Colab, consider changing environment"

In [14]:
# The SAEs were trained in torch.float32, but they also work with torch.float16.
# float32 needs a GPU with more than 30 GB of memory; switch to torch.float16 on smaller GPUs.
dtype = torch.float32
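If you prefer not to edit the value by hand, the choice can be derived from the GPU's total memory. The sketch below is a hypothetical helper (`pick_dtype_name` is not part of the repository). It assumes the 30 GB threshold from the comment in the cell above, treated here as GiB. In the notebook you would pass it `torch.cuda.get_device_properties(0).total_memory` and turn the returned name into a dtype with `getattr(torch, name)`.

```python
# Hypothetical helper: choose the pipeline/SAE dtype from total GPU memory.
# Assumption: the 30 GB threshold comes from the notebook comment (treated as GiB).
GIB = 1024 ** 3


def pick_dtype_name(total_memory_bytes: int, threshold_gb: float = 30.0) -> str:
    """Return "float32" if the GPU has more than `threshold_gb` GiB, else "float16"."""
    if total_memory_bytes > threshold_gb * GIB:
        return "float32"
    return "float16"


if __name__ == "__main__":
    # In the notebook (requires a CUDA device):
    #   import torch
    #   name = pick_dtype_name(torch.cuda.get_device_properties(0).total_memory)
    #   dtype = getattr(torch, name)
    print(pick_dtype_name(80 * GIB))  # float32, e.g. an 80 GB card
    print(pick_dtype_name(16 * GIB))  # float16, e.g. a 16 GB card
```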

In [5]:
sdxl_turbo = "stabilityai/sdxl-turbo"
sdxl = "stabilityai/stable-diffusion-xl-base-1.0"
model_name = sdxl_turbo  # the demo uses SDXL Turbo

pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="balanced",  # spread pipeline components across the available GPUs
    variant=("fp16" if dtype == torch.float16 else None),  # fp16 weights only in half precision
)
pipe.set_progress_bar_config(disable=True)

In [15]:
path_to_checkpoints = './checkpoints/'

code_to_block = {
    "down.2.1": "unet.down_blocks.2.attentions.1",
    "mid.0": "unet.mid_block.attentions.0",
    "up.0.1": "unet.up_blocks.0.attentions.1",
    "up.0.0": "unet.up_blocks.0.attentions.0"
}
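Each value in `code_to_block` is a dotted path into the pipeline: names are attributes and integers index into module lists, so `"unet.down_blocks.2.attentions.1"` corresponds to `pipe.unet.down_blocks[2].attentions[1]`. With real PyTorch modules the same lookup is available as `pipe.unet.get_submodule("down_blocks.2.attentions.1")`. The pure-Python sketch below illustrates the idea without loading the model; `resolve_block` and the toy classes are hypothetical stand-ins, not repository code.

```python
# Hypothetical helper: walk a dotted block path such as
# "unet.down_blocks.2.attentions.1" the way attribute/index access would.
from typing import Any


def resolve_block(root: Any, path: str) -> Any:
    """Follow `path` from `root`: integer parts index into sequences,
    other parts are looked up as attributes."""
    obj = root
    for part in path.split("."):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj


# Toy stand-ins for the real pipeline objects (illustration only).
class Block:
    def __init__(self, name: str, attentions=None):
        self.name = name
        self.attentions = attentions or []


class UNet:
    def __init__(self):
        self.down_blocks = [
            Block(f"down.{i}", [Block(f"down.{i}.attn.{j}") for j in range(2)])
            for i in range(3)
        ]


class Pipe:
    def __init__(self):
        self.unet = UNet()


pipe_stub = Pipe()
target = resolve_block(pipe_stub, "unet.down_blocks.2.attentions.1")
print(target.name)  # down.2.attn.1
```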

saes_dict = {}
means_dict = {}

for code, block in code_to_block.items():
    # Each checkpoint folder holds the trained SAE and a mean.pt tensor for its block
    checkpoint_dir = os.path.join(
        path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"
    )
    sae = SparseAutoencoder.load_from_disk(checkpoint_dir)
    means = torch.load(os.path.join(checkpoint_dir, "mean.pt"), weights_only=True)
    # Move both to the GPU in the dtype chosen above
    saes_dict[code] = sae.to('cuda', dtype=dtype)
    means_dict[code] = means.to('cuda', dtype=dtype)
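The checkpoint folder names appear to encode the SAE training hyperparameters. Reading `k10_hidden5120_auxk256_bs4096_lr0.0001` as k = 10 active features, 5120 hidden features, an auxiliary k of 256, batch size 4096 and learning rate 0.0001 is an interpretation of the naming scheme, not something this notebook states. The hypothetical `parse_run_name` below only splits such a name into those assumed fields.

```python
# Hypothetical parser for the checkpoint folder names used above.
# Assumption: the field meanings (k, hidden, auxk, bs, lr) are inferred from the name.
import re

FIELDS = {
    "k": int,       # number of active features per position (assumed TopK)
    "hidden": int,  # number of SAE features
    "auxk": int,    # auxiliary k
    "bs": int,      # batch size
    "lr": float,    # learning rate
}


def parse_run_name(name: str) -> dict:
    """Extract the typed hyperparameter fields from a checkpoint folder name."""
    params = {}
    for key, cast in FIELDS.items():
        # Match "_<key><number>" followed by "_" or the end of the string.
        match = re.search(rf"_{key}([0-9.eE-]+)(?=_|$)", name)
        if match:
            params[key] = cast(match.group(1))
    return params


run = "unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001"
print(parse_run_name(run))
# {'k': 10, 'hidden': 5120, 'auxk': 256, 'bs': 4096, 'lr': 0.0001}
```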

### Launching Demo
Once the demo is running, you can also open it through the public link printed in the output of the cell below.

Interesting features to look at:
- `down.2.1 #4998`: cartoon feature
- `down.2.1 #230`: "fury" feature
- `down.2.1 #89`: "muscleman" feature
- `down.2.1 #4074`: anime feature
- `up.0.1   #4977`: tiger stripes
- `up.0.1   #90`: fur
- `up.0.1   #2165`: twilight blur
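Each entry pairs a block code from `code_to_block` with an index into that block's SAE feature vector. Assuming a TopK SAE, which the `k10` in the checkpoint names suggests, only the k largest pre-activations are kept at each position, and a feature counts as active there when its index survives that selection. The plain-Python sketch below illustrates this with a hypothetical `topk_encode` and made-up numbers. The repository's `SparseAutoencoder` class is the real encoder and may differ in details, for example in how it uses the stored mean.

```python
# Hypothetical TopK encoding step, in plain Python for illustration.
# Assumption: keep the k largest pre-activations and drop non-positive ones
# (a common TopK SAE formulation; not necessarily the repository's exact code).
from typing import Dict, List


def topk_encode(pre_acts: List[float], k: int) -> Dict[int, float]:
    """Return {feature_index: activation} for the k largest positive pre-activations."""
    ranked = sorted(range(len(pre_acts)), key=lambda i: pre_acts[i], reverse=True)
    return {i: pre_acts[i] for i in ranked[:k] if pre_acts[i] > 0}


def is_active(code: Dict[int, float], feature: int) -> bool:
    """True if `feature` survived the top-k selection at this position."""
    return feature in code


# Toy pre-activations for an 8-feature SAE at a single position (made-up numbers).
pre = [0.1, 2.5, -0.3, 0.0, 1.7, 0.9, -1.2, 3.1]
code = topk_encode(pre, k=3)
print(code)                # {7: 3.1, 1: 2.5, 4: 1.7}
print(is_active(code, 4))  # True
print(is_active(code, 5))  # False
```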

In [17]:
demo = create_demo(pipe, saes_dict, means_dict, use_retrieval=True)
demo.launch(share=True, height=1200)

In [8]:
# Uncomment to shut down the running demo server
# demo.close()

## Citation

If you find this notebook useful in your research, please cite our paper:

```bibtex
@misc{surkov2024unpackingsdxlturbointerpreting,
      title={Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders},
      author={Viacheslav Surkov and Chris Wendler and Mikhail Terekhov and Justin Deschenaux and Robert West and Caglar Gulcehre},
      year={2024},
      eprint={2410.22366},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.22366},
}
```