Commit 5156e7a · Working version
Parent(s): 3f98781

app.py CHANGED
@@ -1,44 +1,57 @@
 # !pip install diffusers
-
+import torch
+from diffusers import DDIMPipeline, DDPMPipeline, PNDMPipeline
+from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
+from diffusers import UNetUnconditionalModel
 import gradio as gr
 import PIL.Image
 import numpy as np
 import random
-import torch
 
 model_id = "google/ddpm-celebahq-256"
+model = UNetUnconditionalModel.from_pretrained(model_id, subfolder="unet")
 
 # load model and scheduler
-
-
-
+ddpm_scheduler = DDPMScheduler.from_config(model_id, subfolder="scheduler")
+ddpm_pipeline = DDPMPipeline(unet=model, scheduler=ddpm_scheduler)
+
+ddim_scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+ddim_pipeline = DDIMPipeline(unet=model, scheduler=ddim_scheduler)
 
+pndm_scheduler = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+pndm_pipeline = PNDMPipeline(unet=model, scheduler=pndm_scheduler)
 # run pipeline in inference (sample random noise and denoise)
-def predict(
+def predict(seed=42,scheduler="ddim"):
+    torch.cuda.empty_cache()
     generator = torch.manual_seed(seed)
     if(scheduler == "ddim"):
-        image =
+        image = ddim_pipeline(generator=generator, num_inference_steps=100)
        image = image["sample"]
     elif(scheduler == "ddpm"):
-        image =
+        image = ddpm_pipeline(generator=generator)
+        #["sample"] doesnt work here for some reason
     elif(scheduler == "pndm"):
-        image =
+        image = pndm_pipeline(generator=generator, num_inference_steps=11)
+        #["sample"] doesnt work here for some reason
-
-
-    # process image to PIL
+
     image_processed = image.cpu().permute(0, 2, 3, 1)
-
-
-
+    if scheduler == "pndm":
+        image_processed = (image_processed + 1.0) / 2
+        image_processed = torch.clamp(image_processed, 0.0, 1.0)
+        image_processed = image_processed * 255
+    else:
+        image_processed = (image_processed + 1.0) * 127.5
+    image_processed = image_processed.detach().numpy().astype(np.uint8)
+    return(PIL.Image.fromarray(image_processed[0]))
 
 
 random_seed = random.randint(0, 2147483647)
 gr.Interface(
     predict,
     inputs=[
-        gr.inputs.Slider(1, 1000, label='Inference Steps', default=
+        #gr.inputs.Slider(1, 1000, label='Inference Steps', default=20, step=1),
        gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed),
        gr.inputs.Radio(["ddim", "ddpm", "pndm"], default="ddpm",label="Diffusion scheduler")
     ],
-    outputs="
+    outputs=gr.Image(shape=[256,256], type="pil"),
 ).launch()
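Note on the new post-processing in predict(): the pndm branch rescales the sample from [-1, 1] to [0, 1], clamps it, and then multiplies by 255, while the other schedulers use the single affine map (x + 1.0) * 127.5. For values already in [-1, 1] the two are equivalent. A minimal sketch of that arithmetic follows; the tensor x is just a stand-in for a denoised sample and is not part of the commit.

import numpy as np
import torch

x = torch.linspace(-1.0, 1.0, steps=5)  # stand-in for a denoised sample in [-1, 1]

# "pndm" branch: rescale to [0, 1], clamp, then scale to [0, 255]
pndm_branch = torch.clamp((x + 1.0) / 2, 0.0, 1.0) * 255

# other schedulers: single affine map to [0, 255]
other_branch = (x + 1.0) * 127.5

assert torch.allclose(pndm_branch, other_branch)
print(pndm_branch.numpy().astype(np.uint8))  # [  0  63 127 191 255]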