"""
"""
# Upgrade PyTorch: pull a recent cu126 nightly (kept below 2.9) plus the
# spaces package before torch is imported below.
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Actual app.py
import os
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from zerogpu import aoti_compile
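
# Load FLUX.1-schnell in bfloat16 and move the whole pipeline to the GPU.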
pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
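
# Reserve one long ZeroGPU slot (duration in seconds) for the one-off
# quantize + export + compile step.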
@spaces.GPU(duration=1500)
def compile_transformer():
    # Fuse the attention QKV projections, then quantize the transformer's
    # linear layers to float8 (dynamic activation and weight scaling).
    pipeline.transformer.fuse_qkv_projections()
    quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

    def _example_tensor(*shape):
        return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)

    # Timestep-distilled checkpoints (like schnell) have no guidance embedding
    # and use a shorter text sequence.
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
    seq_length = 256 if is_timestep_distilled else 512

    # Example inputs matching the transformer's forward signature;
    # torch.export traces the model with these shapes.
    transformer_kwargs = {
        'hidden_states': _example_tensor(1, 4096, 64),
        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
        'pooled_projections': _example_tensor(1, 768),
        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
        'txt_ids': _example_tensor(seq_length, 3),
        'img_ids': _example_tensor(4096, 3),
        'joint_attention_kwargs': {},
        'return_dict': False,
    }

    # Inductor settings used for the ahead-of-time compile.
    inductor_configs = {
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }

    exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
    return aoti_compile(exported, inductor_configs)
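
# Swap in the compiled transformer; its config attribute is saved and restored
# around the swap since the compiled artifact does not carry it.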
transformer_config = pipeline.transformer.config
pipeline.transformer = compile_transformer()
pipeline.transformer.config = transformer_config
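
# Default ZeroGPU slot for inference: run nine 4-step generations, yielding
# the gallery after each image so results appear progressively.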
@spaces.GPU
def generate_image(prompt: str):
    t0 = datetime.now()
    images = []
    for _ in range(9):
        image = pipeline(prompt, num_inference_steps=4).images[0]
        # Walrus update: reset t0 to "now" while computing elapsed = now - previous t0.
        elapsed = -(t0 - (t0 := datetime.now()))
        images += [(image, f'{elapsed.total_seconds():.2f}s')]
        yield images
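
# Prompt in, gallery of (image, elapsed-time caption) pairs out.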
gr.Interface(generate_image, gr.Text(), gr.Gallery(rows=3, columns=3, height='60vh')).launch()