# Copyright (c) 2025 Fudan University. All rights reserved.
import dataclasses
import json
import os
from io import BytesIO
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import cv2
import gradio as gr
import insightface
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFilter
from PIL.JpegImagePlugin import JpegImageFile

from withanyone_kontext_s.flux.pipeline import WithAnyonePipeline
from util import extract_moref, face_preserving_resize
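
# Gradio demo for WithAnyone (Kontext variant): faces in a user-supplied base
# image are detected, blurred in place, and re-generated from reference
# identities by a FLUX-based pipeline conditioned on ArcFace embeddings and
# per-face bounding boxes.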


def blur_faces_in_image(img, json_data, face_size_threshold=100, blur_radius=15):
    """
    Blurs facial areas directly in the original image for privacy protection.

    Args:
        img: PIL Image or raw image bytes
        json_data: JSON object with 'bboxes' and optional 'crop' information
        face_size_threshold: Minimum face width/height in pixels to be blurred (default: 100)
        blur_radius: Strength of the Gaussian blur (higher = more blurred)

    Returns:
        PIL Image with faces blurred
    """
    # Ensure img is a PIL Image (JpegImageFile is already an Image.Image subclass,
    # so it does not need its own check)
    if not isinstance(img, (Image.Image, torch.Tensor)):
        img = Image.open(BytesIO(img))
    new_bboxes = json_data["bboxes"]
    # crop = json_data['crop'] if 'crop' in json_data else [0, 0, img.width, img.height]
    # # Recalculate bounding boxes based on crop info
    # new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]

    # Keep only faces that are large enough to be worth blurring
    valid_bboxes = []
    for bbox in new_bboxes:
        x1, y1, x2, y2 = bbox
        if x2 - x1 >= face_size_threshold and y2 - y1 >= face_size_threshold:
            valid_bboxes.append(bbox)

    # If no valid faces were found, return the original image
    if not valid_bboxes:
        return img

    # Work on a copy of the original image
    blurred_img = img.copy()

    # Process each face
    for bbox in valid_bboxes:
        x1, y1, x2, y2 = map(int, bbox)
        # Clamp coordinates to the image boundaries
        img_width, img_height = img.size
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(img_width, x2)
        y2 = min(img_height, y2)
        # Blur the face region and paste it back into the copy
        face_region = img.crop((x1, y1, x2, y2))
        blurred_face = face_region.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        blurred_img.paste(blurred_face, (x1, y1))

    return blurred_img
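
# Example usage (hypothetical path): blur one detected face of at least 100x100 px:
#   img = Image.open("group.jpg")
#   blurred = blur_faces_in_image(img, {"bboxes": [[120, 80, 260, 240]]})
#   blurred.save("group_blurred.jpg")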


def captioner(prompt: str, num_person: int = 1) -> List[List[float]]:
    # Randomly sample a face layout for testing; `prompt` is currently unused.
    # All coordinates lie within a 512x512 canvas.
    if num_person == 1:
        bbox_choices = [
            # expanded, centered and quadrant placements
            [96, 96, 288, 288],
            [128, 128, 320, 320],
            [160, 96, 352, 288],
            [96, 160, 288, 352],
            [208, 96, 400, 288],
            [96, 208, 288, 400],
            [192, 160, 368, 336],
            [64, 128, 224, 320],
            [288, 128, 448, 320],
            [128, 256, 320, 448],
            [80, 80, 240, 272],
            [196, 196, 380, 380],
            # originals
            [100, 100, 300, 300],
            [150, 50, 450, 350],
            [200, 100, 500, 400],
            [250, 150, 512, 450],
        ]
        return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
    elif num_person == 2:
        # realistic side-by-side rows (no vertical stacks or diagonals)
        bbox_choices = [
            [[64, 112, 224, 304], [288, 112, 448, 304]],
            [[48, 128, 208, 320], [304, 128, 464, 320]],
            [[32, 144, 192, 336], [320, 144, 480, 336]],
            [[80, 96, 240, 288], [272, 96, 432, 288]],
            [[80, 160, 240, 352], [272, 160, 432, 352]],
            [[64, 128, 240, 336], [272, 144, 432, 320]],  # slight stagger, same row
            [[96, 160, 256, 352], [288, 160, 448, 352]],
            [[64, 192, 224, 384], [288, 192, 448, 384]],  # lower row
            [[16, 128, 176, 320], [336, 128, 496, 320]],  # near edges
            [[48, 120, 232, 328], [280, 120, 464, 328]],
            [[96, 160, 240, 336], [272, 160, 416, 336]],  # tighter faces
            [[72, 136, 232, 328], [280, 152, 440, 344]],  # small vertical offset
            [[48, 120, 224, 344], [288, 144, 448, 336]],  # asymmetric sizes
            [[80, 224, 240, 416], [272, 224, 432, 416]],  # bottom row
            [[80, 64, 240, 256], [272, 64, 432, 256]],  # top row
            [[96, 176, 256, 368], [288, 176, 448, 368]],
        ]
        return bbox_choices[np.random.randint(0, len(bbox_choices))]
    elif num_person == 3:
        # Non-overlapping 3-person layouts within 512x512
        bbox_choices = [
            [[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
            [[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
            [[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
            [[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
            [[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
            [[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
            [[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
            [[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
        ]
        return bbox_choices[np.random.randint(0, len(bbox_choices))]
    elif num_person == 4:
        # Non-overlapping 4-person layouts within 512x512
        bbox_choices = [
            [[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
            [[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
            [[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
            [[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
            [[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
            [[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
            [[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
            [[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
        ]
        return bbox_choices[np.random.randint(0, len(bbox_choices))]
    else:
        raise ValueError(f"Unsupported number of persons: {num_person} (expected 1-4)")
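
# Example: captioner("two friends at a cafe", num_person=2) might return
# [[64, 112, 224, 304], [288, 112, 448, 304]] (one randomly chosen side-by-side layout).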


class FaceExtractor:
    def __init__(self, model_path="./"):
        # insightface resolves the antelopev2 weights as <root>/models/antelopev2
        self.model = insightface.app.FaceAnalysis(name="antelopev2", root=model_path)
        self.model.prepare(ctx_id=0)
    def extract(self, image: Image.Image):
        """Extract a single face crop and embedding from an image"""
        image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        res = self.model.get(image_np)
        if len(res) == 0:
            return None, None
        res = res[0]
        bbox = res["bbox"]
        moref = extract_moref(image, {"bboxes": [bbox]}, 1)
        return moref[0], res["embedding"]
    def extract_refs(self, image: Image.Image):
        """Extract multiple faces and embeddings from an image"""
        image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        res = self.model.get(image_np)
        if len(res) == 0:
            return None, None, None, None
        ref_imgs = []
        arcface_embeddings = []
        bboxes = []
        for r in res:
            bbox = r["bbox"]
            bboxes.append(bbox)
            moref = extract_moref(image, {"bboxes": [bbox]}, 1)
            ref_imgs.append(moref[0])
            arcface_embeddings.append(r["embedding"])
        # Resize the image to the 512px layout space and remap the bboxes accordingly
        new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
        return ref_imgs, arcface_embeddings, new_bboxes, new_img
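
# Example usage (hypothetical path):
#   extractor = FaceExtractor()
#   crops, embeddings, bboxes, resized = extractor.extract_refs(Image.open("group.jpg"))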


def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
    """Resize bounding box coordinates while preserving aspect ratio"""
    x1, y1, x2, y2 = bbox
    # Calculate scaling factors
    width_scale = new_width / ori_width
    height_scale = new_height / ori_height
    # Use the minimum scaling factor to preserve aspect ratio
    min_scale = min(width_scale, height_scale)
    # Calculate offsets for centering the scaled box
    width_offset = (new_width - ori_width * min_scale) / 2
    height_offset = (new_height - ori_height * min_scale) / 2
    # Scale and adjust coordinates
    new_x1 = int(x1 * min_scale + width_offset)
    new_y1 = int(y1 * min_scale + height_offset)
    new_x2 = int(x2 * min_scale + width_offset)
    new_y2 = int(y2 * min_scale + height_offset)
    return [new_x1, new_y1, new_x2, new_y2]
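
# Worked example: mapping [100, 100, 300, 300] from a 512x512 layout onto a
# 768x512 canvas gives min_scale = min(768/512, 512/512) = 1.0 and
# width_offset = (768 - 512) / 2 = 128, so the box becomes [228, 100, 428, 300].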


def draw_bboxes_on_image(image, bboxes):
    """Draw bounding boxes on an image for visualization"""
    if bboxes is None:
        return image
    # Create a copy to draw on
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    # Cycle through a small palette so each bbox gets a distinct color
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        x1, y1, x2, y2 = [int(coord) for coord in bbox]
        # Draw the rectangle and its label
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        draw.text((x1, y1 - 15), f"Face {i+1}", fill=color)
    return img_draw


def create_demo(
    model_type: str = "flux-dev",
    ipa_path: str = "./ckpt/ipa.safetensors",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    lora_rank: int = 64,
    additional_lora_ckpt: Optional[str] = None,
    lora_scale: float = 1.0,
    clip_path: str = "openai/clip-vit-large-patch14",
    t5_path: str = "xlabs-ai/xflux_text_encoders",
    flux_path: str = "black-forest-labs/FLUX.1-dev",
):
    # Initialize the face extractor and generation pipeline
    face_extractor = FaceExtractor()
    pipeline = WithAnyonePipeline(
        model_type,
        ipa_path,
        device,
        offload,
        only_lora=True,
        no_lora=True,
        lora_rank=lora_rank,
        additional_lora_ckpt=additional_lora_ckpt,
        lora_weight=lora_scale,
        face_extractor=face_extractor,
        clip_path=clip_path,
        t5_path=t5_path,
        flux_path=flux_path,
    )
    def parse_bboxes(bbox_text):
        """Parse bounding box text input (one 'x1,y1,x2,y2' line per face)"""
        if not bbox_text or bbox_text.strip() == "":
            return None
        try:
            bboxes = []
            lines = bbox_text.strip().split("\n")
            for line in lines:
                if not line.strip():
                    continue
                coords = [float(x) for x in line.strip().split(",")]
                if len(coords) != 4:
                    raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
                bboxes.append(coords)
            return bboxes
        except Exception as e:
            raise gr.Error(f"Invalid bbox format: {e}")
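
    # Example: parse_bboxes("100,100,200,200\n300,100,400,200")
    # -> [[100.0, 100.0, 200.0, 200.0], [300.0, 100.0, 400.0, 200.0]]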
    def extract_from_base_image(base_img):
        """Extract reference crops, embeddings, and bboxes from the base image"""
        if base_img is None:
            return None, None, None, None, None
        # Convert from numpy to PIL if needed
        if isinstance(base_img, np.ndarray):
            base_img = Image.fromarray(base_img)
        ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(base_img)
        if ref_imgs is None or len(ref_imgs) == 0:
            raise gr.Error("No faces detected in the base image")
        # Limit to at most 4 faces
        ref_imgs = ref_imgs[:4]
        arcface_embeddings = arcface_embeddings[:4]
        bboxes = bboxes[:4]
        # Create a visualization with the detected bboxes
        viz_image = draw_bboxes_on_image(new_img, bboxes)
        # Format bboxes as text for display
        bbox_text = "\n".join(f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes)
        return ref_imgs, arcface_embeddings, bboxes, viz_image, bbox_text
    def process_and_generate(
        prompt,
        guidance, num_steps, seed,
        ref_img1, ref_img2, ref_img3, ref_img4,
        base_img,
        use_text_prompt,
        manual_bboxes_text=None,
        siglip_weight=0.0,
    ):
        # Validate that a base image was provided
        if base_img is None:
            raise gr.Error("Base image is required")

        # Convert numpy to PIL if needed
        if isinstance(base_img, np.ndarray):
            base_img = Image.fromarray(base_img)

        # Output dimensions follow the base image
        width, height = base_img.size

        # Collect and validate reference images
        ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
        if not ref_images:
            raise gr.Error("At least one reference image is required")

        # Face crops and embeddings extracted from the reference images
        ref_imgs = []
        arcface_embeddings = []

        # Detect faces in the base image (bboxes are in 512x512 layout space)
        extracted_refs, extracted_embeddings, bboxes_, _, _ = extract_from_base_image(base_img)
        if extracted_refs is None:
            raise gr.Error("No faces detected in the base image. Please provide a different base image with clear faces.")

        # A manual bbox override (also in 512x512 layout space) takes precedence
        manual_bboxes = parse_bboxes(manual_bboxes_text)
        if manual_bboxes:
            bboxes_ = manual_bboxes

        # Map bboxes from layout space to the base image resolution
        bboxes__ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]

        # Create the conditioning canvas by blurring the faces in the base image
        blurred_canvas = blur_faces_in_image(base_img, {"bboxes": bboxes__})
        bboxes = [bboxes__]  # Wrap in a list for the batch input format
        # Process each reference image
        for img in ref_images:
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            ref_img, embedding = face_extractor.extract(img)
            if ref_img is None or embedding is None:
                raise gr.Error("Failed to extract face from one of the reference images")
            ref_imgs.append(ref_img)
            arcface_embeddings.append(embedding)

        if len(bboxes[0]) != len(ref_imgs):
            raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")

        # Convert ArcFace embeddings to a tensor
        arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
        arcface_embeddings = torch.stack(arcface_embeddings).to(device)
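        # Stacked shape is (num_refs, D); antelopev2's recognition model emits 512-d ArcFace embeddings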

        # Generate the image
        final_prompt = prompt if use_text_prompt else ""
        if seed < 0:
            seed = np.random.randint(0, 1000000)
        image_gen = pipeline(
            prompt=final_prompt,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=int(seed),
            ref_imgs=ref_imgs,
            img_cond=blurred_canvas,  # the blurred base image acts as the spatial condition
            arcface_embeddings=arcface_embeddings,
            bboxes=bboxes,
            max_num_ids=len(ref_imgs),
            siglip_weight=siglip_weight,
            id_weight=1,  # only ArcFace is supported for now
            arc_only=True,
        )
        # Save a temp file for download
        temp_path = "temp_generated.png"
        image_gen.save(temp_path)

        # Draw bboxes on the generated image for debugging
        debug_face = draw_bboxes_on_image(image_gen, bboxes[0])

        return image_gen, debug_face, temp_path

    def update_bbox_display(base_img):
        if base_img is None:
            return None, None
        try:
            _, _, _, viz_image, bbox_text = extract_from_base_image(base_img)
            return viz_image, bbox_text
        except Exception:
            return None, None
    # Create the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# WithAnyone Kontext Demo")

        with gr.Row():
            with gr.Column():
                # Input controls
                generate_btn = gr.Button("Generate", variant="primary")
                with gr.Row():
                    prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
                    use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
                        seed = gr.Number(-1, label="Seed (-1 for random)")
        with gr.Row():
            with gr.Column():
                # Reference image inputs
                gr.Markdown("### Face References (1-4 required)")
                ref_img1 = gr.Image(label="Reference 1", type="pil")
                ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
                ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
                ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
            with gr.Column():
                # Base image input: combines the previous canvas and multi-person image
                gr.Markdown("### Base Image (Required)")
                base_img = gr.Image(label="Base Image - faces will be detected and replaced", type="pil")
                bbox_preview = gr.Image(label="Detected Faces", type="pil")
                gr.Markdown("### Manual Bounding Box Override (Optional)")
                manual_bbox_input = gr.Textbox(
                    label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
                    lines=4,
                    placeholder="100,100,200,200\n300,100,400,200",
                )
            with gr.Column():
                # Output display
                output_image = gr.Image(label="Generated Image")
                debug_face = gr.Image(label="Debug: Faces are expected to be generated in these boxes")
                download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
        # Examples section
        with gr.Row():
            gr.Markdown("""
            # Example Configurations

            ### Tips for Better Results
            - The base image is required: faces in it will be detected, blurred, and then replaced
            - Provide clear reference images with visible faces
            - Use detailed prompts describing the desired output
            - If the detected boxes look wrong, override them in the manual bounding box field
            """)
        with gr.Row():
            examples = gr.Examples(
                examples=[
                    [
                        "",  # prompt
                        4.0, 25, 42,  # guidance, num_steps, seed
                        "assets/ref3.jpg", "assets/ref1.jpg", None, None,  # ref images
                        "assets/canvas.jpg",  # base image
                        False,  # use_text_prompt
                    ]
                ],
                inputs=[
                    prompt, guidance, num_steps, seed,
                    ref_img1, ref_img2, ref_img3, ref_img4,
                    base_img, use_text_prompt,
                ],
                label="Click to load example configurations",
            )
        # Set up event handlers
        base_img.change(
            fn=update_bbox_display,
            inputs=[base_img],
            outputs=[bbox_preview, manual_bbox_input],
        )

        generate_btn.click(
            fn=process_and_generate,
            inputs=[
                prompt, guidance, num_steps, seed,
                ref_img1, ref_img2, ref_img3, ref_img4,
                base_img, use_text_prompt, manual_bbox_input,
            ],
            outputs=[output_image, debug_face, download_btn],
        )

    return demo


if __name__ == "__main__":
    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class AppArgs:
        model_type: Literal["flux-dev", "flux-kontext", "flux-schnell"] = "flux-kontext"
        device: Literal["cuda", "mps", "cpu"] = (
            "cuda" if torch.cuda.is_available()
            else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
            else "cpu"
        )
        offload: bool = False
        lora_rank: int = 64
        port: int = 7860
        additional_lora: Optional[str] = None
        lora_scale: float = 1.0
        ipa_path: str = "./ckpt/ipa.safetensors"
        clip_path: str = "openai/clip-vit-large-patch14"
        t5_path: str = "xlabs-ai/xflux_text_encoders"
        flux_path: str = "black-forest-labs/FLUX.1-dev"

    parser = HfArgumentParser([AppArgs])
    args = parser.parse_args_into_dataclasses()[0]
    demo = create_demo(
        args.model_type,
        args.ipa_path,
        args.device,
        args.offload,
        args.lora_rank,
        args.additional_lora,
        args.lora_scale,
        args.clip_path,
        args.t5_path,
        args.flux_path,
    )
    demo.launch(server_port=args.port)
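
# Launch with defaults (assumes the ./ckpt weights and insightface ./models
# directory are in place, and that this file is saved as app.py):
#   python app.py
# HfArgumentParser exposes each AppArgs field as a CLI flag, e.g.:
#   python app.py --model_type flux-kontext --port 7861 --offload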