# !pip install diffusers
import torch
from diffusers import DDIMPipeline, DDPMPipeline, PNDMPipeline
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
from diffusers import UNetUnconditionalModel
import gradio as gr
import PIL.Image
import numpy as np
import random
# load the model and one scheduler per sampling method
model_id = "google/ddpm-celebahq-256"
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder="unet")

ddpm_scheduler = DDPMScheduler.from_config(model_id, subfolder="scheduler")
ddpm_pipeline = DDPMPipeline(unet=model, scheduler=ddpm_scheduler)

ddim_scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
ddim_pipeline = DDIMPipeline(unet=model, scheduler=ddim_scheduler)

pndm_scheduler = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pndm_pipeline = PNDMPipeline(unet=model, scheduler=pndm_scheduler)
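# Note (assumption, not part of the original Space): if a GPU is available, each pipeline
# could be moved to it before sampling, e.g.
#   ddpm_pipeline.to("cuda")
# provided the pinned diffusers version supports DiffusionPipeline.to(); otherwise
# inference runs on CPU and torch.cuda.empty_cache() below is effectively a no-op.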
# run pipeline in inference (sample random noise and denoise)
def predict(seed=42, scheduler="ddim"):
    torch.cuda.empty_cache()
    generator = torch.manual_seed(seed)

    if scheduler == "ddim":
        image = ddim_pipeline(generator=generator, num_inference_steps=100)
        image = image["sample"]
    elif scheduler == "ddpm":
        image = ddpm_pipeline(generator=generator)
        # indexing with ["sample"] doesn't work here for some reason
    elif scheduler == "pndm":
        image = pndm_pipeline(generator=generator, num_inference_steps=11)
        # indexing with ["sample"] doesn't work here for some reason

    # (batch, channels, height, width) -> (batch, height, width, channels)
    image_processed = image.cpu().permute(0, 2, 3, 1)

    # map the model output from [-1, 1] to [0, 255]
    if scheduler == "pndm":
        image_processed = (image_processed + 1.0) / 2
        image_processed = torch.clamp(image_processed, 0.0, 1.0)
        image_processed = image_processed * 255
    else:
        image_processed = (image_processed + 1.0) * 127.5

    image_processed = image_processed.detach().numpy().astype(np.uint8)
    return PIL.Image.fromarray(image_processed[0])
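# Quick local sanity check (hypothetical usage, kept commented out so the Space only
# starts the Gradio UI below; "sample.png" is an arbitrary output path):
# predict(seed=0, scheduler="ddim").save("sample.png")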
random_seed = random.randint(0, 2147483647)

gr.Interface(
    predict,
    inputs=[
        # gr.inputs.Slider(1, 1000, label='Inference Steps', default=20, step=1),
        gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed),
        gr.inputs.Radio(["ddim", "ddpm", "pndm"], default="ddpm", label="Diffusion scheduler"),
    ],
    outputs=gr.Image(shape=[256, 256], type="pil"),
).launch()