Linoy Tsaban
commited on
Commit
·
6a5a59b
1
Parent(s):
af56f98
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,12 +14,6 @@ import re
|
|
| 14 |
|
| 15 |
|
| 16 |
|
| 17 |
-
def randomize_seed_fn(seed, randomize_seed):
|
| 18 |
-
if randomize_seed:
|
| 19 |
-
seed = random.randint(0, np.iinfo(np.int32).max)
|
| 20 |
-
torch.manual_seed(seed)
|
| 21 |
-
return seed
|
| 22 |
-
|
| 23 |
|
| 24 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
| 25 |
|
|
@@ -116,8 +110,29 @@ def get_example():
|
|
| 116 |
]]
|
| 117 |
return case
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
input_image,
|
| 122 |
do_inversion,
|
| 123 |
seed, randomize_seed,
|
|
@@ -127,7 +142,7 @@ def invert_and_reconstruct(
|
|
| 127 |
steps=100,
|
| 128 |
src_cfg_scale = 3.5,
|
| 129 |
skip=36,
|
| 130 |
-
tar_cfg_scale=15
|
| 131 |
|
| 132 |
):
|
| 133 |
|
|
@@ -140,10 +155,7 @@ def invert_and_reconstruct(
|
|
| 140 |
wts = gr.State(value=wts_tensor)
|
| 141 |
zs = gr.State(value=zs_tensor)
|
| 142 |
do_inversion = False
|
| 143 |
-
|
| 144 |
-
# output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 145 |
-
|
| 146 |
-
# return output, wts, zs, do_inversion
|
| 147 |
return wts, zs, do_inversion
|
| 148 |
|
| 149 |
|
|
@@ -244,7 +256,10 @@ with gr.Blocks(css='style.css') as demo:
|
|
| 244 |
else:
|
| 245 |
return row2.update(visible=True), row3.update(visible=True), plus.update(visible=False), 3
|
| 246 |
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
| 248 |
def reset_do_inversion():
|
| 249 |
do_inversion = True
|
| 250 |
return do_inversion
|
|
@@ -255,15 +270,16 @@ with gr.Blocks(css='style.css') as demo:
|
|
| 255 |
zs = gr.State()
|
| 256 |
do_inversion = gr.State(value=True)
|
| 257 |
sega_concepts_counter = gr.State(1)
|
|
|
|
| 258 |
|
| 259 |
|
| 260 |
|
| 261 |
with gr.Row():
|
| 262 |
input_image = gr.Image(label="Input Image", interactive=True)
|
| 263 |
-
|
| 264 |
sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
|
| 265 |
input_image.style(height=365, width=365)
|
| 266 |
-
|
| 267 |
sega_edited_image.style(height=365, width=365)
|
| 268 |
|
| 269 |
with gr.Tabs() as tabs:
|
|
@@ -322,12 +338,13 @@ with gr.Blocks(css='style.css') as demo:
|
|
| 322 |
)
|
| 323 |
|
| 324 |
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
| 325 |
-
|
| 326 |
|
| 327 |
|
| 328 |
with gr.Row():
|
| 329 |
with gr.Column(scale=1, min_width=100):
|
| 330 |
run_button = gr.Button("Run")
|
|
|
|
| 331 |
# with gr.Column(scale=1, min_width=100):
|
| 332 |
# edit_button = gr.Button("Edit")
|
| 333 |
|
|
@@ -350,16 +367,25 @@ with gr.Blocks(css='style.css') as demo:
|
|
| 350 |
|
| 351 |
|
| 352 |
|
| 353 |
-
|
| 354 |
outputs= [row2, row3, plus, sega_concepts_counter], queue = False)
|
| 355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
run_button.click(
|
| 358 |
fn = randomize_seed_fn,
|
| 359 |
inputs = [seed, randomize_seed],
|
| 360 |
outputs = [seed],
|
| 361 |
queue = False).then(
|
| 362 |
-
fn=
|
| 363 |
inputs=[input_image,
|
| 364 |
do_inversion,
|
| 365 |
seed, randomize_seed,
|
|
@@ -369,10 +395,10 @@ with gr.Blocks(css='style.css') as demo:
|
|
| 369 |
steps,
|
| 370 |
src_cfg_scale,
|
| 371 |
skip,
|
| 372 |
-
tar_cfg_scale
|
| 373 |
],
|
| 374 |
-
# outputs=[ddpm_edited_image, wts, zs, do_inversion],
|
| 375 |
outputs=[wts, zs, do_inversion],
|
|
|
|
| 376 |
).success(
|
| 377 |
fn=edit,
|
| 378 |
inputs=[input_image,
|
|
@@ -389,8 +415,17 @@ with gr.Blocks(css='style.css') as demo:
|
|
| 389 |
|
| 390 |
],
|
| 391 |
outputs=[sega_edited_image],
|
|
|
|
|
|
|
|
|
|
| 392 |
)
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
# Automatically start inverting upon input_image change
|
| 395 |
input_image.change(
|
| 396 |
fn = reset_do_inversion,
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
| 19 |
|
|
|
|
| 110 |
]]
|
| 111 |
return case
|
| 112 |
|
| 113 |
+
def randomize_seed_fn(seed, randomize_seed):
|
| 114 |
+
if randomize_seed:
|
| 115 |
+
seed = random.randint(0, np.iinfo(np.int32).max)
|
| 116 |
+
torch.manual_seed(seed)
|
| 117 |
+
return seed
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
|
| 122 |
+
def reconstruct(tar_prompt,
|
| 123 |
+
tar_cfg_scale,
|
| 124 |
+
skip,
|
| 125 |
+
wts, zs,
|
| 126 |
+
# do_reconstruction,
|
| 127 |
+
# reconstruction
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
):
|
| 131 |
+
# if do_reconstruction:
|
| 132 |
+
reconstruction = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 133 |
+
return reconstruction
|
| 134 |
+
|
| 135 |
+
def load_and_invert(
|
| 136 |
input_image,
|
| 137 |
do_inversion,
|
| 138 |
seed, randomize_seed,
|
|
|
|
| 142 |
steps=100,
|
| 143 |
src_cfg_scale = 3.5,
|
| 144 |
skip=36,
|
| 145 |
+
tar_cfg_scale=15
|
| 146 |
|
| 147 |
):
|
| 148 |
|
|
|
|
| 155 |
wts = gr.State(value=wts_tensor)
|
| 156 |
zs = gr.State(value=zs_tensor)
|
| 157 |
do_inversion = False
|
| 158 |
+
|
|
|
|
|
|
|
|
|
|
| 159 |
return wts, zs, do_inversion
|
| 160 |
|
| 161 |
|
|
|
|
| 256 |
else:
|
| 257 |
return row2.update(visible=True), row3.update(visible=True), plus.update(visible=False), 3
|
| 258 |
|
| 259 |
+
def show_reconstruction_option():
|
| 260 |
+
return reconstruct_button.update(visible=True)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
def reset_do_inversion():
|
| 264 |
do_inversion = True
|
| 265 |
return do_inversion
|
|
|
|
| 270 |
zs = gr.State()
|
| 271 |
do_inversion = gr.State(value=True)
|
| 272 |
sega_concepts_counter = gr.State(1)
|
| 273 |
+
# reconstruction = gr.State()
|
| 274 |
|
| 275 |
|
| 276 |
|
| 277 |
with gr.Row():
|
| 278 |
input_image = gr.Image(label="Input Image", interactive=True)
|
| 279 |
+
ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
|
| 280 |
sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
|
| 281 |
input_image.style(height=365, width=365)
|
| 282 |
+
ddpm_edited_image.style(height=512, width=512)
|
| 283 |
sega_edited_image.style(height=365, width=365)
|
| 284 |
|
| 285 |
with gr.Tabs() as tabs:
|
|
|
|
| 338 |
)
|
| 339 |
|
| 340 |
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
| 341 |
+
add_concept_button = gr.Button("+")
|
| 342 |
|
| 343 |
|
| 344 |
with gr.Row():
|
| 345 |
with gr.Column(scale=1, min_width=100):
|
| 346 |
run_button = gr.Button("Run")
|
| 347 |
+
reconstruct_button = gr.Button("Show me the reconstruction")
|
| 348 |
# with gr.Column(scale=1, min_width=100):
|
| 349 |
# edit_button = gr.Button("Edit")
|
| 350 |
|
|
|
|
| 367 |
|
| 368 |
|
| 369 |
|
| 370 |
+
add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
|
| 371 |
outputs= [row2, row3, plus, sega_concepts_counter], queue = False)
|
| 372 |
|
| 373 |
+
reconstruct_button.click(
|
| 374 |
+
fn = reconstruct,
|
| 375 |
+
inputs = [tar_prompt,
|
| 376 |
+
tar_cfg_scale,
|
| 377 |
+
skip,
|
| 378 |
+
wts, zs]
|
| 379 |
+
outputs = [ddpm_edited_image]
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
|
| 383 |
run_button.click(
|
| 384 |
fn = randomize_seed_fn,
|
| 385 |
inputs = [seed, randomize_seed],
|
| 386 |
outputs = [seed],
|
| 387 |
queue = False).then(
|
| 388 |
+
fn=load_and_invert,
|
| 389 |
inputs=[input_image,
|
| 390 |
do_inversion,
|
| 391 |
seed, randomize_seed,
|
|
|
|
| 395 |
steps,
|
| 396 |
src_cfg_scale,
|
| 397 |
skip,
|
| 398 |
+
tar_cfg_scale
|
| 399 |
],
|
|
|
|
| 400 |
outputs=[wts, zs, do_inversion],
|
| 401 |
+
|
| 402 |
).success(
|
| 403 |
fn=edit,
|
| 404 |
inputs=[input_image,
|
|
|
|
| 415 |
|
| 416 |
],
|
| 417 |
outputs=[sega_edited_image],
|
| 418 |
+
).success(
|
| 419 |
+
fn = show_reconstruction_option,
|
| 420 |
+
outputs = [reconstruct_button]
|
| 421 |
)
|
| 422 |
|
| 423 |
+
reconstruct_button.click(
|
| 424 |
+
fn =
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
|
| 429 |
# Automatically start inverting upon input_image change
|
| 430 |
input_image.change(
|
| 431 |
fn = reset_do_inversion,
|