Spaces: Running on Zero
```python
# Download the InstantIR weights from the Hub into ./models
from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="InstantX/InstantIR", filename="models/adapter.pt", local_dir=".")
hf_hub_download(repo_id="InstantX/InstantIR", filename="models/aggregator.pt", local_dir=".")
hf_hub_download(repo_id="InstantX/InstantIR", filename="models/previewer_lora_weights.bin", local_dir=".")
```
```python
import numpy as np
import torch
from PIL import Image

from diffusers import DDPMScheduler

from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline


def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    # Scale the image so its long side fits max_side, then snap both sides
    # down to a multiple of base_pixel_number (SDXL expects multiples of 64).
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        # ratio = min_side / min(h, w)
        # w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        # Center the resized image on a white max_side x max_side canvas
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image
```
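For reference, a quick check of what `resize_img` produces (the file name is a placeholder): the long side is scaled to 1280 and both dimensions are snapped down to multiples of 64, unless an explicit `size` is given.

```python
# Placeholder file name; a 3000x2000 photo illustrates the two code paths.
img = Image.open("example.jpg")
print(resize_img(img).size)                      # (1280, 832): scaled by 1280/3000, snapped to 64
print(resize_img(img, size=(1024, 1024)).size)   # (1024, 1024): an explicit size wins
```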
```python
# prepare models under ./models
instantir_path = './models'

# load pretrained models
pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)

# load adapter
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path='facebook/dinov2-large',
)

# load previewer lora
pipe.prepare_previewers(instantir_path)
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# load aggregator weights
pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
pipe.aggregator.load_state_dict(pretrained_state_dict)

# send to GPU and fp16
pipe.to(device='cuda', dtype=torch.float16)
pipe.aggregator.to(device='cuda', dtype=torch.float16)
```
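On GPUs with less memory than the full fp16 pipeline needs, diffusers pipelines usually offer `enable_model_cpu_offload()`. Whether `InstantIRPipeline` and its aggregator behave correctly under offloading is an assumption here, so treat this as an optional sketch rather than part of the recipe (it would replace the `pipe.to(device='cuda')` call above and requires `accelerate`).

```python
# Optional sketch (assumption: InstantIRPipeline supports the standard diffusers
# offloading hooks and the aggregator tolerates them). Requires `accelerate`.
# Use in place of pipe.to(device='cuda') when VRAM is tight.
if torch.cuda.get_device_properties(0).total_memory < 16 * 1024**3:  # arbitrary 16 GB cut-off
    pipe.enable_model_cpu_offload()
```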
```python
PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
ultra HD, extreme meticulous detailing, skin pore detailing, \
hyper sharpness, perfect without deformations, \
taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "

NEG_PROMPT = "blurry, out of focus, unclear, depth of field, over-smooth, \
sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
watermark, signature, jpeg artifacts, deformed, lowres"
```
```python
def infer(prompt, input_image, steps=30, cfg_scale=7.0, guidance_end=1.0,
          creative_restoration=False, seed=3407, height=1024, width=1024):
    # load a broken image
    low_quality_image = Image.open(input_image).convert("RGB")
    lq = [resize_img(low_quality_image, size=(width, height))]

    generator = torch.Generator(device='cuda').manual_seed(seed)

    # evenly spaced custom timesteps for the restoration schedule
    timesteps = [
        i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
    ]
    timesteps = timesteps[::-1]

    # fall back to the default restoration prompt if none is given
    prompt = PROMPT if len(prompt)==0 else prompt
    neg_prompt = NEG_PROMPT

    # InstantIR restoration
    image = pipe(
        prompt=[prompt]*len(lq),
        image=lq,
        num_inference_steps=steps,
        generator=generator,
        timesteps=timesteps,
        negative_prompt=[neg_prompt]*len(lq),
        guidance_scale=cfg_scale,
        previewer_scheduler=lcm_scheduler,
    ).images[0]

    return image
```
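Since this demo targets ZeroGPU ("Running on Zero"), the GPU-bound function would normally be wrapped with the `spaces.GPU` decorator so a device is attached only for the duration of each call. The decorator comes from the standard `spaces` package available on ZeroGPU Spaces; wrapping `infer` like this is a sketch, not part of the original snippet.

```python
# ZeroGPU sketch: only valid on a Space where the `spaces` package is available.
import spaces

@spaces.GPU(duration=120)  # optional cap, in seconds, on how long the GPU is held
def infer_zerogpu(prompt, input_image):
    # thin wrapper so the GPU is allocated only while restoration runs
    return infer(prompt, input_image)
```

If you use the wrapper, point `submit_btn.click` below at `infer_zerogpu` instead of `infer`.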
```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                with gr.Group():
                    prompt = gr.Textbox(label="Prompt", value="")
                    submit_btn = gr.Button("InstantIR magic!")
            output_img = gr.Image(label="InstantIR restored")
    submit_btn.click(
        fn=infer,
        inputs=[prompt, lq_img],
        outputs=[output_img]
    )

demo.launch(show_error=True)
```
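To exercise the pipeline without the Gradio UI, `infer` can also be called directly; the file names below are placeholders.

```python
# Headless usage (placeholder file names, not part of the demo):
restored = infer("", "low_quality_input.jpg")  # empty prompt falls back to PROMPT
restored.save("restored_output.png")
```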