File size: 2,092 Bytes
2cc8fc5
cc0b502
2cc8fc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from PIL import Image
import random

from wfControl.src.flux.condition import Condition
from wfControl.src.flux.generate import generate, seed_everything

# Build the FLUX.1-dev pipeline once at import time (bfloat16 weights on the
# GPU); generate_image() reuses this module-level `pipe` on every call.
print("Loading model...")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Clear LoRA state before registering the style adapters.
# NOTE(review): on a freshly loaded pipeline this is presumably a no-op.
pipe.unload_lora_weights()

# (weight file inside the hub repo, adapter name) for every art style;
# generate_image() activates exactly one of these adapter names per call.
for _weight_name, _adapter_name in (
    ("v0/ghibli.safetensors", "ghibli"),
    ("v0/irasutoya.safetensors", "irasutoya"),
    ("v0/simpsons.safetensors", "simpsons"),
    ("v0/snoopy.safetensors", "snoopy"),
):
    pipe.load_lora_weights(
        "Yuanshi/OminiControlArt",
        weight_name=_weight_name,
        adapter_name=_adapter_name,
    )
del _weight_name, _adapter_name

def generate_image(image, style, prompt, seed=None):
    """Restyle ``image`` with one of the OminiControlArt LoRA adapters.

    The input is converted to RGB, scaled so its longer side is 512 px,
    center-cropped so both sides are multiples of 16, and passed to
    ``generate`` as a "subject" condition together with ``prompt``.

    Args:
        image: Input image (a PIL image; ``.convert``/``.resize``/``.crop``
            are called on it).
        style: Display name of the style: "Studio Ghibli",
            "Irasutoya Illustration", "The Simpsons" or "Snoopy". Any other
            value falls back to the "ghibli" adapter.
        prompt: Text prompt forwarded to ``generate``.
        seed: Optional seed for ``seed_everything``. ``None`` (the default)
            draws a random seed in ``[0, 2**32 - 1]``, as before; pass an
            int to reproduce a result.

    Returns:
        ``generate(...).images[0]``, requested at 640x640.

    Raises:
        ValueError: If ``image`` is ``None`` or is too small/narrow to
            produce a non-empty 16-px-aligned crop.
    """
    if image is None:
        raise ValueError("generate_image() requires an input image")

    def resize(img, factor=16):
        """Center-crop ``img`` so both sides are multiples of ``factor``."""
        w, h = img.size
        new_w, new_h = w // factor * factor, h // factor * factor
        padding_w, padding_h = (w - new_w) // 2, (h - new_h) // 2
        return img.crop((padding_w, padding_h, new_w + padding_w, new_h + padding_h))

    adapter_name = {
        "Studio Ghibli": "ghibli",
        "Irasutoya Illustration": "irasutoya",
        "The Simpsons": "simpsons",
        "Snoopy": "snoopy",
    }.get(style, "ghibli")
    # NOTE(review): this switches the active adapter on the shared
    # module-level pipeline, so concurrent calls would race on it.
    pipe.set_adapters(adapter_name)

    # Convert before resizing: PIL forces NEAREST resampling for "P"/"1"
    # images, and an alpha channel would presumably yield a 4-channel
    # condition the VAE does not expect (TODO confirm against Condition).
    image = image.convert("RGB")

    scale = 512 / max(image.size)
    # round(), not int(): e.g. 512 / 784 * 784 == 511.99999999999994, which
    # int() truncates to 511 and the 16-px crop then shrinks to 496.
    image = resize(
        image.resize(
            (round(image.size[0] * scale), round(image.size[1] * scale)),
            Image.LANCZOS,
        )
    )
    if min(image.size) == 0:
        raise ValueError(
            "input image is too small or too narrow: a side is under "
            "16 px after scaling the longer side to 512 px"
        )

    # Horizontal offset for the condition equal to minus its width in 16-px
    # units (the crop factor above); width is a multiple of 16, so this is
    # exact. Presumably places the condition tokens beside the generated
    # image's tokens (TODO confirm in Condition).
    delta = -image.size[0] // 16
    condition = Condition("subject", image, position_delta=(0, delta))

    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    seed_everything(seed)

    result_img = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
        num_inference_steps=20,
        width=640,
        height=640,
        image_guidance_scale=1.0,
        default_lora=True,
        max_sequence_length=32,
    ).images[0]

    return result_img