Spaces: Running on Zero
import spaces
import gradio as gr
import numpy as np
import random
from PIL import Image
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    EulerDiscreteScheduler,
    AutoencoderKL
)
from transformers import DPTForDepthEstimation, DPTImageProcessor
from diffusers.utils import load_image
import boto3
from io import BytesIO
from datetime import datetime
import json
device = "cuda"
base_model_id = "SG161222/RealVisXL_V5.0"
controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0"
vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
if torch.cuda.is_available():
    # Load the depth ControlNet, the fp16-fixed SDXL VAE and the RealVisXL base model.
    controlnet = ControlNetModel.from_pretrained(
        controlnet_model_id,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.bfloat16
    )
    vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.bfloat16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        base_model_id,
        controlnet=controlnet,
        vae=vae,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.bfloat16,
    )
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)

    # Depth estimator used to build the ControlNet conditioning image.
    depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
USE_TORCH_COMPILE = 0
ENABLE_CPU_OFFLOAD = 0
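# USE_TORCH_COMPILE and ENABLE_CPU_OFFLOAD are defined above but not wired up.
# A minimal sketch of how they could be applied, assuming standard diffusers /
# PyTorch APIs (not part of the original app):
#
# if torch.cuda.is_available():
#     if ENABLE_CPU_OFFLOAD:
#         pipe.enable_model_cpu_offload()  # lower VRAM use at some speed cost
#     if USE_TORCH_COMPILE:
#         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)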
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
def get_depth_map(image):
    # Estimate a depth map with MiDaS and return it as a 3-channel PIL image
    # at the original resolution, for use as the ControlNet conditioning image.
    original_size = (image.size[1], image.size[0])  # PIL size is (W, H); interpolate expects (H, W)
    print("start generate depth", original_size)
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=original_size,
        mode="bicubic",
        align_corners=False,
    )
    # Normalize to [0, 1] and replicate to three channels.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    print("generate depth success")
    return image
def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name):
    # Upload the generated PNG to a Cloudflare R2 bucket through the S3-compatible API.
    print("upload_image_to_s3", account_id, bucket_name)  # do not log credentials
    connection_url = f"https://{account_id}.r2.cloudflarestorage.com"
    s3 = boto3.client(
        's3',
        endpoint_url=connection_url,
        region_name='auto',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file
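# Note: upload_image_to_s3 returns the R2 object key, not a browsable URL. If a
# shareable link is wanted, one option (an assumption, not in the original app)
# is a time-limited presigned URL from the same boto3 client:
#
# presigned_url = s3.generate_presigned_url(
#     "get_object",
#     Params={"Bucket": bucket_name, "Key": image_file},
#     ExpiresIn=3600,  # valid for one hour
# )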
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call
def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    print("process start")
    if image_url:
        print(image_url)
        original_image = load_image(image_url)
    else:
        original_image = Image.fromarray(image)
    # Width/height are taken from the input image; SDXL expects multiples of 8.
    size = (original_image.size[0], original_image.size[1])
    print("original image size", size)
    depth_image = get_depth_map(original_image)
    generator = torch.Generator().manual_seed(int(seed))
    print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
    print("run pipe")
    generated_image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        width=size[0],
        height=size[1],
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        # The UI's "Control Strength" maps to controlnet_conditioning_scale;
        # the text-to-image ControlNet pipeline has no `strength` argument.
        controlnet_conditioning_scale=float(control_strength),
        generator=generator,
        image=depth_image
    ).images[0]
    print("generate image success")
    if upload_to_s3:
        url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket)
        result = {"status": "success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}
    return generated_image, json.dumps(result)
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image()
            image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=True):
                num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                control_strength = gr.Slider(label="Control Strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                n_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
                upload_to_s3 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
        with gr.Column():
            result = gr.Image(label="Generated Image")
            logs = gr.Textbox(label="logs")
    inputs = [
        image,
        image_url,
        prompt,
        n_prompt,
        num_steps,
        guidance_scale,
        control_strength,
        seed,
        upload_to_s3,
        account_id,
        access_key,
        secret_key,
        bucket
    ]
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,
        inputs=inputs,
        outputs=[result, logs],
        api_name="predict"
    )

demo.queue().launch()
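# -----------------------------------------------------------------------------
# Usage sketch (not part of the app itself): one way to call the exposed
# "/predict" endpoint from another machine with gradio_client. The Space id
# "user/space-name" and the sample arguments are placeholders; the positional
# arguments mirror the `inputs` list wired to `process` above.
#
# from gradio_client import Client
#
# client = Client("user/space-name")  # hypothetical Space id
# image_path, logs_json = client.predict(
#     None,                                  # image (unused when a URL is given)
#     "https://example.com/room.png",        # image_url
#     "a bright scandinavian living room",   # prompt
#     "lowres, bad anatomy, worst quality",  # n_prompt
#     30,                                    # num_steps
#     7.5,                                   # guidance_scale
#     0.8,                                   # control_strength
#     42,                                    # seed
#     False, "", "", "", "",                 # R2 upload options (disabled)
#     api_name="/predict",
# )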