Dimitri
commited on
Commit
·
92cf4eb
1
Parent(s):
61d0d14
fix demo
Browse files
app.py
CHANGED
|
@@ -10,13 +10,33 @@ from fabric.generator import AttentionBasedGenerator
|
|
| 10 |
model_name = ""
|
| 11 |
model_ckpt = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_7_pruned.safetensors"
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
css = """
|
|
@@ -96,33 +116,44 @@ def generate_fn(
|
|
| 96 |
liked = []
|
| 97 |
disliked = disliked[-max_feedback_imgs:]
|
| 98 |
# else: keep all feedback images
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
prompt
|
| 102 |
-
negative_prompt
|
| 103 |
-
liked
|
| 104 |
-
disliked
|
| 105 |
-
denoising_steps
|
| 106 |
-
guidance_scale
|
| 107 |
-
feedback_start
|
| 108 |
-
feedback_end
|
| 109 |
-
min_weight
|
| 110 |
-
max_weight
|
| 111 |
-
neg_scale
|
| 112 |
-
seed
|
| 113 |
-
n_images
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
return [(img, f"Image {i+1}") for i, img in enumerate(images)], images
|
| 116 |
except Exception as err:
|
| 117 |
raise gr.Error(str(err))
|
| 118 |
|
| 119 |
|
| 120 |
def add_img_from_list(i, curr_imgs, all_imgs):
|
|
|
|
|
|
|
| 121 |
if i >= 0 and i < len(curr_imgs):
|
| 122 |
all_imgs.append(curr_imgs[i])
|
| 123 |
return all_imgs, all_imgs # return (gallery, state)
|
| 124 |
|
| 125 |
def add_img(img, all_imgs):
|
|
|
|
|
|
|
| 126 |
all_imgs.append(img)
|
| 127 |
return None, all_imgs, all_imgs
|
| 128 |
|
|
@@ -148,7 +179,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 148 |
with gr.Column():
|
| 149 |
denoising_steps = gr.Slider(1, 100, value=20, step=1, label="Sampling steps")
|
| 150 |
guidance_scale = gr.Slider(0.0, 30.0, value=6, step=0.25, label="CFG scale")
|
| 151 |
-
batch_size = gr.Slider(1, 10, value=4, step=1, label="Batch size")
|
| 152 |
seed = gr.Number(-1, minimum=-1, precision=0, label="Seed")
|
| 153 |
max_feedback_imgs = gr.Slider(0, 20, value=6, step=1, label="Max. feedback images", info="Maximum number of liked/disliked images to be used. If exceeded, only the most recent images will be used as feedback. (NOTE: large number of feedback imgs => high VRAM requirements)")
|
| 154 |
feedback_enabled = gr.Checkbox(True, label="Enable feedback", interactive=True)
|
|
@@ -222,8 +253,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 222 |
liked_img_input.upload(add_img, [liked_img_input, liked_imgs], [liked_img_input, like_gallery, liked_imgs], queue=False)
|
| 223 |
disliked_img_input.upload(add_img, [disliked_img_input, disliked_imgs], [disliked_img_input, dislike_gallery, disliked_imgs], queue=False)
|
| 224 |
|
| 225 |
-
clear_liked_btn.click(lambda: [
|
| 226 |
-
clear_disliked_btn.click(lambda: [
|
| 227 |
|
| 228 |
-
demo.queue(
|
| 229 |
-
demo.launch()
|
|
|
|
| 10 |
model_name = ""
|
| 11 |
model_ckpt = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_7_pruned.safetensors"
|
| 12 |
|
| 13 |
+
class GeneratorWrapper:
|
| 14 |
+
def __init__(self, model_name=None, model_ckpt=None):
|
| 15 |
+
self.model_name = model_name if model_name else None
|
| 16 |
+
self.model_ckpt = model_ckpt if model_ckpt else None
|
| 17 |
+
self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 18 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
|
| 20 |
+
self.reload()
|
| 21 |
+
|
| 22 |
+
def generate(self, *args, **kwargs):
|
| 23 |
+
return self.generator.generate(*args, **kwargs)
|
| 24 |
+
|
| 25 |
+
def to(self, device):
|
| 26 |
+
return self.generator.to(device)
|
| 27 |
+
|
| 28 |
+
def reload(self):
|
| 29 |
+
if hasattr(self, "generator"):
|
| 30 |
+
del self.generator
|
| 31 |
+
if self.device == "cuda":
|
| 32 |
+
torch.cuda.empty_cache()
|
| 33 |
+
self.generator = AttentionBasedGenerator(
|
| 34 |
+
model_name=self.model_name,
|
| 35 |
+
model_ckpt=self.model_ckpt,
|
| 36 |
+
torch_dtype=self.dtype,
|
| 37 |
+
).to(self.device)
|
| 38 |
+
|
| 39 |
+
generator = GeneratorWrapper(model_name, model_ckpt)
|
| 40 |
|
| 41 |
|
| 42 |
css = """
|
|
|
|
| 116 |
liked = []
|
| 117 |
disliked = disliked[-max_feedback_imgs:]
|
| 118 |
# else: keep all feedback images
|
| 119 |
+
|
| 120 |
+
generate_kwargs = {
|
| 121 |
+
"prompt": prompt,
|
| 122 |
+
"negative_prompt": neg_prompt,
|
| 123 |
+
"liked": liked,
|
| 124 |
+
"disliked": disliked,
|
| 125 |
+
"denoising_steps": denoising_steps,
|
| 126 |
+
"guidance_scale": guidance_scale,
|
| 127 |
+
"feedback_start": feedback_start,
|
| 128 |
+
"feedback_end": feedback_end,
|
| 129 |
+
"min_weight": min_weight,
|
| 130 |
+
"max_weight": max_weight,
|
| 131 |
+
"neg_scale": neg_scale,
|
| 132 |
+
"seed": seed,
|
| 133 |
+
"n_images": batch_size,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
images = generator.generate(**generate_kwargs)
|
| 138 |
+
except RuntimeError as err:
|
| 139 |
+
if 'out of memory' in str(err):
|
| 140 |
+
generator.reload()
|
| 141 |
+
raise
|
| 142 |
return [(img, f"Image {i+1}") for i, img in enumerate(images)], images
|
| 143 |
except Exception as err:
|
| 144 |
raise gr.Error(str(err))
|
| 145 |
|
| 146 |
|
| 147 |
def add_img_from_list(i, curr_imgs, all_imgs):
|
| 148 |
+
if all_imgs is None:
|
| 149 |
+
all_imgs = []
|
| 150 |
if i >= 0 and i < len(curr_imgs):
|
| 151 |
all_imgs.append(curr_imgs[i])
|
| 152 |
return all_imgs, all_imgs # return (gallery, state)
|
| 153 |
|
| 154 |
def add_img(img, all_imgs):
|
| 155 |
+
if all_imgs is None:
|
| 156 |
+
all_imgs = []
|
| 157 |
all_imgs.append(img)
|
| 158 |
return None, all_imgs, all_imgs
|
| 159 |
|
|
|
|
| 179 |
with gr.Column():
|
| 180 |
denoising_steps = gr.Slider(1, 100, value=20, step=1, label="Sampling steps")
|
| 181 |
guidance_scale = gr.Slider(0.0, 30.0, value=6, step=0.25, label="CFG scale")
|
| 182 |
+
batch_size = gr.Slider(1, 10, value=4, step=1, label="Batch size", interactive=False)
|
| 183 |
seed = gr.Number(-1, minimum=-1, precision=0, label="Seed")
|
| 184 |
max_feedback_imgs = gr.Slider(0, 20, value=6, step=1, label="Max. feedback images", info="Maximum number of liked/disliked images to be used. If exceeded, only the most recent images will be used as feedback. (NOTE: large number of feedback imgs => high VRAM requirements)")
|
| 185 |
feedback_enabled = gr.Checkbox(True, label="Enable feedback", interactive=True)
|
|
|
|
| 253 |
liked_img_input.upload(add_img, [liked_img_input, liked_imgs], [liked_img_input, like_gallery, liked_imgs], queue=False)
|
| 254 |
disliked_img_input.upload(add_img, [disliked_img_input, disliked_imgs], [disliked_img_input, dislike_gallery, disliked_imgs], queue=False)
|
| 255 |
|
| 256 |
+
clear_liked_btn.click(lambda: [[], []], None, [liked_imgs, like_gallery], queue=False)
|
| 257 |
+
clear_disliked_btn.click(lambda: [[], []], None, [disliked_imgs, dislike_gallery], queue=False)
|
| 258 |
|
| 259 |
+
demo.queue(1)
|
| 260 |
+
demo.launch(debug=True)
|