from typing import TypedDict

import diffusers.image_processor
import gradio as gr
import pillow_heif  # pyright: ignore[reportMissingTypeStubs]
import spaces  # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image

from pipeline import TryOffAnyone

pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]
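
# Allow TF32 (and faster) matmul kernels on supported NVIDIA GPUs:
# slightly lower float32 precision in exchange for faster inference.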
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
| TITLE = """ | |
| # Try Off Anyone | |
| ## ⚠️ Important | |
| 1. Choose an example image or upload your own | |
| 2. Use the Pen tool to draw a mask over the clothing area you want to extract | |
| [](https://arxiv.org/abs/2412.08573) | |
| """ | |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)
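
# Two diffusers image processors sharing the VAE's 8x scale factor:
# one binarizes the hand-drawn grayscale mask, the other normalizes the RGB photo
# into the value range the VAE expects.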
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)
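
# Shape of the value produced by gr.ImageMask with type="pil":
# the uploaded photo, the drawn mask layers, and their composite.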
class ImageData(TypedDict):
    background: Image.Image
    composite: Image.Image
    layers: list[Image.Image]
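
# Assumption: on a ZeroGPU Space the GPU-bound handler is wrapped with the
# spaces.GPU decorator (the spaces import is otherwise unused here).
@spaces.GPU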
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    assert image_width > 0
    assert image_height > 0
    assert num_inference_steps > 0
    assert condition_scale > 0
    assert seed >= 0

    # extract image and mask from image_data
    image = image_data["background"]
    mask = image_data["layers"][0]

    # preprocess image
    image = image.convert("RGB").resize((image_width, image_height))
    image_preprocessed = vae_processor.preprocess(  # pyright: ignore[reportUnknownMemberType,reportAssignmentType]
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # preprocess mask
    mask = mask.getchannel("A").resize((image_width, image_height))
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # generate the TryOff image
    gen = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=gen,
    )[0]

    return tryoff_image
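
# Gradio UI: masked image input and example photos on the left,
# the extracted garment and generation settings on the right.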
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,  # https://github.com/gradio-app/gradio/issues/10236
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
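        # Right column: the generated garment image plus the generation controls.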
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
            with gr.Accordion("Advanced Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    value=69_420,
                    step=1,
                )
                scale = gr.Slider(
                    label="Scale",
                    minimum=0.5,
                    maximum=5,
                    value=2.5,
                    step=0.05,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    value=25,
                    step=1,
                )
                with gr.Row():
                    image_width = gr.Slider(
                        label="Image Width",
                        minimum=64,
                        maximum=1024,
                        value=384,
                        step=8,
                    )
                    image_height = gr.Slider(
                        label="Image Height",
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=8,
                    )
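
    # Wire the button: pass the masked image and the slider values to process()
    # and display the returned garment image.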
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )

demo.launch()
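
# To run locally (assuming this file is saved as app.py and the dependencies,
# including the local `pipeline` module, are installed): `python app.py`.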