Linoy Tsaban
commited on
Commit
·
8b5d4bf
1
Parent(s):
d19d91b
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,7 +28,7 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
|
|
| 28 |
def caption_image(input_image):
|
| 29 |
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
|
| 30 |
pixel_values = inputs.pixel_values
|
| 31 |
-
|
| 32 |
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
| 33 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 34 |
return generated_caption
|
|
@@ -38,9 +38,9 @@ def caption_image(input_image):
|
|
| 38 |
## DDPM INVERSION AND SAMPLING ##
|
| 39 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
| 40 |
|
| 41 |
-
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
| 42 |
# based on the code in https://github.com/inbarhub/DDPM_inversion
|
| 43 |
-
|
| 44 |
# returns wt, zs, wts:
|
| 45 |
# wt - inverted latent
|
| 46 |
# wts - intermediate inverted latents
|
|
@@ -50,7 +50,7 @@ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta
|
|
| 50 |
|
| 51 |
# vae encode image
|
| 52 |
with inference_mode():
|
| 53 |
-
|
| 54 |
|
| 55 |
# find Zs and wts - forward process
|
| 56 |
wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
|
|
@@ -61,10 +61,10 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
|
|
| 61 |
|
| 62 |
# reverse process (via Zs and wT)
|
| 63 |
w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
|
| 64 |
-
|
| 65 |
# vae decode image
|
| 66 |
with inference_mode():
|
| 67 |
-
|
| 68 |
if x0_dec.dim()<4:
|
| 69 |
x0_dec = x0_dec[None,:,:,:]
|
| 70 |
img = image_grid(x0_dec)
|
|
@@ -142,7 +142,7 @@ def edit(input_image,
|
|
| 142 |
src_cfg_scale):
|
| 143 |
|
| 144 |
if do_inversion or randomize_seed:
|
| 145 |
-
x0 = load_512(input_image, device=device)
|
| 146 |
# invert and retrieve noise maps and latent
|
| 147 |
zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
|
| 148 |
wts = gr.State(value=wts_tensor)
|
|
|
|
| 28 |
def caption_image(input_image):
|
| 29 |
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
|
| 30 |
pixel_values = inputs.pixel_values
|
| 31 |
+
|
| 32 |
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
| 33 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 34 |
return generated_caption
|
|
|
|
| 38 |
## DDPM INVERSION AND SAMPLING ##
|
| 39 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
| 40 |
|
| 41 |
+
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
| 42 |
# based on the code in https://github.com/inbarhub/DDPM_inversion
|
| 43 |
+
|
| 44 |
# returns wt, zs, wts:
|
| 45 |
# wt - inverted latent
|
| 46 |
# wts - intermediate inverted latents
|
|
|
|
| 50 |
|
| 51 |
# vae encode image
|
| 52 |
with inference_mode():
|
| 53 |
+
w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215)
|
| 54 |
|
| 55 |
# find Zs and wts - forward process
|
| 56 |
wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
|
|
|
|
| 61 |
|
| 62 |
# reverse process (via Zs and wT)
|
| 63 |
w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
|
| 64 |
+
|
| 65 |
# vae decode image
|
| 66 |
with inference_mode():
|
| 67 |
+
x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
|
| 68 |
if x0_dec.dim()<4:
|
| 69 |
x0_dec = x0_dec[None,:,:,:]
|
| 70 |
img = image_grid(x0_dec)
|
|
|
|
| 142 |
src_cfg_scale):
|
| 143 |
|
| 144 |
if do_inversion or randomize_seed:
|
| 145 |
+
x0 = load_512(input_image, device=device).to(torch.float16)
|
| 146 |
# invert and retrieve noise maps and latent
|
| 147 |
zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
|
| 148 |
wts = gr.State(value=wts_tensor)
|