import io

import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
# ─── Device & Model Setup ─────────────────────────────────────
# Don't initialize CUDA at import time: on ZeroGPU a GPU only exists inside
# functions decorated with @spaces.GPU.
device = None  # set inside the GPU function
dtype = torch.float16

# Models are loaded lazily inside the GPU function as well.
t5_tok = None
t5_mod = None
pipe = None

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
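# Schedulers are swapped at runtime with `from_config`, which rebuilds the new
# scheduler from the current scheduler's config (betas, timestep spacing, etc.)
# so nothing else in the pipeline needs to change; see `infer` below.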
# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    # Load weights onto CPU here; the caller moves the module to the GPU
    # device once inside the @spaces.GPU function.
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    return model
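# A minimal usage sketch (the filename below is hypothetical; real filenames
# come from the dropdown choices populated by `shunt_list`). The adapter's
# forward pass returns the 8-tuple unpacked in `infer`:
#
#   adapter = load_adapter(repo_l, "example_shunt.safetensors", config_l).to(device)
#   anchor, delta, log_sigma, attn1, attn2, tau, g_pred, gate = adapter(t5_seq, clip_seq)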
# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    mat = np.atleast_2d(mat)  # tolerate 1-D gate vectors as well as 2-D maps
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    plt.close(fig)
    # gr.Image does not accept a raw BytesIO; decode it to a PIL image.
    return Image.open(buf)
# ─── SDXL Text Encoding ───────────────────────────────────────
def encode_sdxl_prompt(prompt, negative_prompt=""):
    """Generate CLIP-L and CLIP-G embeddings using SDXL's own text encoders."""
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    tokens_g = pipe.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    # Negative prompts
    neg_tokens_l = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        # CLIP-L (768d): output[0] is the sequence of hidden states
        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

        # CLIP-G (1280d): output[0] is pooled, output[1] is the sequence
        # (the opposite ordering from CLIP-L)
        clip_g_output = pipe.text_encoder_2(tokens_g)
        clip_g_embeds = clip_g_output[1]
        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
        neg_clip_g_embeds = neg_clip_g_output[1]

        # Pooled CLIP-G embeddings, required by SDXL's added conditioning
        pooled_embeds = clip_g_output[0]
        neg_pooled_embeds = neg_clip_g_output[0]

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds,
    }
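# For reference, the shapes this returns (batch size 1, 77 tokens):
#   clip_l / neg_clip_l : [1, 77, 768]
#   clip_g / neg_clip_g : [1, 77, 1280]
#   pooled / neg_pooled : [1, 1280]
# `infer` concatenates the two sequences along the last dimension into the
# [1, 77, 2048] prompt_embeds tensor that the SDXL pipeline expects.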
# ─── Inference ────────────────────────────────────────────────
@spaces.GPU  # request a GPU from ZeroGPU for the duration of this call
def infer(
    prompt, negative_prompt, adapter_l_file, adapter_g_file,
    strength, noise, gate_prob, use_anchor,
    steps, cfg_scale, scheduler_name,
    width, height, seed
):
    global t5_tok, t5_mod, pipe, device

    # CUDA only becomes available inside this decorated function on ZeroGPU.
    # `device` is global so that encode_sdxl_prompt sees the same device.
    device = torch.device("cuda")
    dtype = torch.float16
    seed = int(seed)  # gr.Number may deliver a float; seeding requires an int

    with torch.no_grad():
        # Lazily initialize the T5 tokenizer/encoder on first call
        if t5_tok is None:
            t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
            t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

        # Lazily initialize the SDXL pipeline on first call
        if pipe is None:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                torch_dtype=dtype,
                variant="fp16",
                use_safetensors=True
            ).to(device)

        # Reproducibility
        if seed != -1:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Scheduler
        if scheduler_name in SCHEDULERS:
            pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

        # T5 embeddings
        t5_ids = t5_tok(
            prompt, return_tensors="pt",
            padding="max_length", max_length=77, truncation=True
        ).input_ids.to(device)
        t5_seq = t5_mod(t5_ids).last_hidden_state

        # CLIP embeddings
        clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)

        # Debug shapes
        print(f"T5 seq shape: {t5_seq.shape}")
        print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
        print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")

        # Load the selected adapters (None disables that stream)
        adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
        adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
        # ---- Adapter L ----
        if adapter_l is not None:
            # The adapter weights load in float32, so cast the fp16 CLIP
            # sequence up before the forward pass (assumption: the adapters
            # were trained in fp32).
            anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"].float())
            gate_l_scaled = gate_l * gate_prob
            delta_l_final = delta_l * strength * gate_l_scaled
            clip_l_mod = clip_embeds["clip_l"] + delta_l_final
            if use_anchor:
                clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
            if noise > 0:
                clip_l_mod += torch.randn_like(clip_l_mod) * noise
        else:
            clip_l_mod = clip_embeds["clip_l"]
            delta_l_final = torch.zeros_like(clip_l_mod)
            gate_l_scaled = torch.zeros_like(clip_l_mod)
            g_pred_l = torch.tensor(0.0)
            tau_l = torch.tensor(0.0)

        # ---- Adapter G ----
        if adapter_g is not None:
            anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"].float())
            gate_g_scaled = gate_g * gate_prob
            delta_g_final = delta_g * strength * gate_g_scaled
            clip_g_mod = clip_embeds["clip_g"] + delta_g_final
            if use_anchor:
                clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
            if noise > 0:
                clip_g_mod += torch.randn_like(clip_g_mod) * noise
        else:
            clip_g_mod = clip_embeds["clip_g"]
            delta_g_final = torch.zeros_like(clip_g_mod)
            gate_g_scaled = torch.zeros_like(clip_g_mod)
            g_pred_g = torch.tensor(0.0)
            tau_g = torch.tensor(0.0)
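        # How the modulation above works, per stream:
        #   clip_mod = clip + strength * gate * delta           (additive shift)
        #   clip_mod = (1 - gate) * clip_mod + gate * anchor    (if use_anchor)
        # with gate pre-scaled by gate_prob, so gate_prob = 0 is a no-op, and
        # noise, when enabled, is a plain Gaussian perturbation of the result.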
        # ---- Combine embeddings ----
        prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
        neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)

        # ---- Generate image ----
        generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
        image = pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=clip_embeds["pooled"],
            negative_prompt_embeds=neg_embeds,
            negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
            num_inference_steps=int(steps),  # sliders may deliver floats
            guidance_scale=cfg_scale,
            width=int(width),
            height=int(height),
            num_images_per_prompt=1,
            generator=generator,
        ).images[0]

        return (
            image,
            plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
            plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
            plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
            plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
            f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
            f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
        )
# ─── Gradio Interface ─────────────────────────────────────────
with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
    gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings")

    with gr.Row():
        with gr.Column(scale=1):
            # Prompts
            with gr.Group():
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2
                )

            # Adapters
            with gr.Group():
                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="None"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="None"
                )

            # Adapter Controls
            with gr.Group():
                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor", value=True)

            # Generation Settings
            with gr.Group():
                gr.Markdown("### Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                # precision=0 makes the seed arrive as an int, not a float
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")

            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output
            with gr.Group():
                gr.Markdown("### Generated Image")
                out_img = gr.Image(label="Result", height=400)

            # Visualizations
            with gr.Group():
                gr.Markdown("### Adapter Visualizations")
                with gr.Row():
                    delta_l = gr.Image(label="Δ CLIP-L", height=200)
                    gate_l = gr.Image(label="Gate CLIP-L", height=200)
                with gr.Row():
                    delta_g = gr.Image(label="Δ CLIP-G", height=200)
                    gate_g = gr.Image(label="Gate CLIP-G", height=200)

            # Stats
            with gr.Group():
                gr.Markdown("### Adapter Statistics")
                stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
                stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
    # Event handlers
    def process_adapters(adapter_l_val, adapter_g_val):
        # Convert the "None" dropdown sentinel back to a real None
        adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
        adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
        return adapter_l_processed, adapter_g_processed

    def run_inference(*args):
        # Replace the adapter dropdown values before forwarding the rest
        # of the arguments to infer unchanged
        adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
        new_args = list(args)
        new_args[2] = adapter_l_processed
        new_args[3] = adapter_g_processed
        return infer(*new_args)
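    # Note: Gradio passes these inputs positionally, so the order of `inputs`
    # below must match infer's signature; indices 2 and 3 are the adapter
    # dropdowns rewritten by run_inference above.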
    run_btn.click(
        fn=run_inference,
        inputs=[
            prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
        ],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
    )
if __name__ == "__main__":
    demo.launch()