File size: 929 Bytes
24ec8bb
4e717d6
 
8c2e0d0
4e717d6
9544e60
674e245
69667cb
4e717d6
 
9544e60
69667cb
4f27510
24ec8bb
4e717d6
856d03c
69667cb
24ec8bb
a3d55a6
bb62642
69667cb
24ec8bb
 
f742b99
9544e60
 
da628cb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from optimization import optimize_pipeline_


# Load the FLUX.1-schnell weights in bfloat16, then move the pipeline to the GPU.
pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-schnell',
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to('cuda')

# Pass the pipeline and a sample prompt string to the project-local optimizer.
# NOTE(review): the trailing underscore suggests it modifies `pipeline` in place;
# its exact behavior lives in `optimization` (not shown) — confirm there.
optimize_pipeline_(pipeline, "prompt")


@spaces.GPU
def generate_image(prompt: str):
    """Generate nine images for ``prompt``, streaming the gallery as each finishes.

    Calls the module-level ``pipeline`` nine times with ``num_inference_steps=4``.
    One ``torch.Generator`` seeded with 42 is shared across all calls. Each call
    therefore draws fresh noise, and the whole sequence is repeatable from the
    fixed seed.

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Yields:
        The growing list of ``(image, caption)`` tuples produced so far. Each
        ``caption`` is the time the pipeline call for that image took, formatted
        like ``"1.23s"``.
    """
    generator = torch.Generator(device='cuda').manual_seed(42)
    images = []
    for _ in range(9):
        # perf_counter is monotonic, so timings cannot be skewed by system clock
        # adjustments. Timing only the pipeline call also keeps time spent
        # suspended at `yield` (i.e. in the consumer) out of the caption.
        start = time.perf_counter()
        image = pipeline(prompt, num_inference_steps=4, generator=generator).images[0]
        elapsed = time.perf_counter() - start
        images.append((image, f'{elapsed:.2f}s'))
        # Yield the accumulated list so the gallery fills in progressively.
        yield images


# A single text prompt goes in. The output is a 3x3 gallery that updates
# each time generate_image yields.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=gr.Gallery(rows=3, columns=3, height='60vh'),
    examples=["A cat playing with a ball of yarn"],
    # Don't pre-run the example at startup; it is generated only when clicked.
    cache_examples=False,
)
demo.launch()