"""
"""
# Upgrade PyTorch
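# A recent nightly (cu126 wheel, torch<2.9) is assumed to be needed for the
# torch.export + AOTInductor path used below; installing at import time works
# around the Space's pinned build-time requirements.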
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Actual app.py
import os
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

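# `zerogpu` is a helper module local to this Space (not a PyPI package); it is
# assumed to wrap AOTInductor so the compiled transformer can be loaded inside
# ZeroGPU workers.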
from zerogpu import aoti_compile


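# With `spaces` imported, ZeroGPU allows moving the pipeline to 'cuda' at
# import time, before any GPU is actually attached to the process.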
pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')


@spaces.GPU(duration=1500)
def compile_transformer():

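    # Fuse the Q/K/V projections so the float8 dynamic-activation /
    # float8-weight quantization below applies to fewer, larger matmuls.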
    pipeline.transformer.fuse_qkv_projections()
    quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

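    # torch.export requires representative example inputs; the shapes below
    # are assumed to match FLUX.1-schnell at 1024x1024 (4096 packed latent
    # tokens of width 64).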
    def _example_tensor(*shape):
        return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)

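    # schnell is timestep-distilled and has no guidance embedder, so
    # `guidance` stays None and the text sequence length is 256 rather
    # than 512.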
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
    seq_length = 256 if is_timestep_distilled else 512

    transformer_kwargs = {
        'hidden_states': _example_tensor(1, 4096, 64),
        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
        'pooled_projections': _example_tensor(1, 768),
        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
        'txt_ids': _example_tensor(seq_length, 3),
        'img_ids': _example_tensor(4096, 3),
        'joint_attention_kwargs': {},
        'return_dict': False,
    }

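    # Inductor tuning flags: autotuning and coordinate-descent search trade a
    # long one-off compile for faster kernels, and CUDA graphs reduce kernel
    # launch overhead at inference time.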
    inductor_configs = {
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }

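    # Trace the transformer into an ExportedProgram with static example
    # shapes, then hand it to the Space-local `aoti_compile` helper for
    # AOTInductor compilation.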
    exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)

    return aoti_compile(exported, inductor_configs)


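# The compiled artifact replaces `pipeline.transformer` but does not carry the
# model config, which FluxPipeline still reads at call time, so the original
# config is saved and reattached.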
transformer_config = pipeline.transformer.config
pipeline.transformer = compile_transformer()
pipeline.transformer.config = transformer_config


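# Generate nine 4-step images per prompt, yielding the growing list so the
# Gradio gallery streams each image (captioned with its generation time) as it
# finishes.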
@spaces.GPU
def generate_image(prompt: str):
    t0 = datetime.now()
    images = []
    for _ in range(9):
        image = pipeline(prompt, num_inference_steps=4).images[0]
        now = datetime.now()
        elapsed = now - t0
        t0 = now
        images.append((image, f'{elapsed.total_seconds():.2f}s'))
        yield images


gr.Interface(generate_image, gr.Text(), gr.Gallery(rows=3, columns=3, height='60vh')).launch()