# Copyright (c) 2025 Fudan University. All rights reserved.

import dataclasses
import itertools
import json
import os
import random
from typing import Literal

import numpy as np
import torch
from accelerate import Accelerator
from PIL import Image
from torch.cuda.amp import autocast
from transformers import AutoModelForImageSegmentation, HfArgumentParser

from withanyone.flux.pipeline import WithAnyonePipeline
from util import (
    FaceExtractor,
    extract_moref,
    extract_object,
    general_face_preserving_resize,
    horizontal_concat,
)

# Fallback face boxes, [x1, y1, x2, y2] in pixels of the 512x512 target canvas,
# used when no bbox can be recovered from the sidecar JSON or the source image.
BACK_UP_BBOXES_DOUBLE = [
    [[100, 100, 200, 200], [300, 100, 400, 200]],  # 2 faces
    [[150, 100, 250, 200], [300, 100, 400, 200]],
]
BACK_UP_BBOXES = [  # for a single face
    [[150, 100, 250, 200]],
    [[100, 100, 200, 200]],
    [[200, 100, 300, 200]],
    [[250, 100, 350, 200]],
    [[300, 100, 400, 200]],
]


@dataclasses.dataclass
class InferenceArgs:
    prompt: str | None = None
    image_paths: list[str] | None = None
    eval_json_path: str | None = None
    offload: bool = False
    num_images_per_prompt: int = 1
    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
    width: int = 512
    height: int = 512
    ref_size: int = -1
    num_steps: int = 25
    guidance: float = 4
    seed: int = 1234
    save_path: str = "output/inference"
    only_lora: bool = True
    concat_refs: bool = False
    lora_rank: int = 64
    data_resolution: int = 512
    save_iter: str = "500"
    use_rec: bool = False
    drop_text: bool = False
    use_matting: bool = False
    id_weight: float = 1.0
    siglip_weight: float = 1.0
    bbox_from_json: bool = True
    data_root: str = "./"

    # optional additional LoRA
    additional_lora: str | None = None
    trigger: str = ""
    lora_weight: float = 1.0

    # paths to the IPA checkpoint and the text/image encoders
    ipa_path: str = "./ckpt/ipa.safetensors"
    clip_path: str = "openai/clip-vit-large-patch14"
    t5_path: str = "xlabs-ai/xflux_text_encoders"
    flux_path: str = "black-forest-labs/FLUX.1-dev"
    siglip_path: str = "google/siglip-base-patch16-256-i18n"


def main(args: InferenceArgs):
    accelerator = Accelerator()
    face_extractor = FaceExtractor()

    pipeline = WithAnyonePipeline(
        args.model_type,
        args.ipa_path,
        accelerator.device,
        args.offload,
        only_lora=args.only_lora,
        face_extractor=face_extractor,
        additional_lora_ckpt=args.additional_lora,
        lora_weight=args.lora_weight,
        clip_path=args.clip_path,
        t5_path=args.t5_path,
        flux_path=args.flux_path,
        siglip_path=args.siglip_path,
    )

    # Background-removal model for matting the reference crops (optional).
    birefnet = None
    if args.use_matting:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "briaai/RMBG-2.0", trust_remote_code=True
        ).to(accelerator.device, dtype=torch.bfloat16)

    assert args.eval_json_path is not None, \
        "Please provide eval_json_path. This script only supports batch inference."

    with open(args.eval_json_path, "rt") as f:
        data_dicts = json.load(f)
    data_root = args.data_root
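
    # The eval JSON file is expected to contain a list of per-sample dicts.
    # A minimal sketch of one entry, inferred from the fields read below
    # (hypothetical values; only "prompt" and "image_paths" are required):
    #
    #   {
    #     "prompt": "two people hiking in the mountains",
    #     "image_paths": ["refs/person_a.jpg", "refs/person_b.jpg"],
    #     "ori_img_path": "images/000123.jpg",
    #     "name": "sample_000123"
    #   }
    #
    # With bbox_from_json (the default), a sidecar "<basename>.json" holding
    # "bboxes" (and optionally "crop") is also read from data_root.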

    for (i, data_dict), j in itertools.product(enumerate(data_dicts), range(args.num_images_per_prompt)):
        # Round-robin shard the (sample, repeat) pairs across processes.
        if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
            continue
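        # For example (hypothetical numbers): with num_images_per_prompt=2 and
        # two processes, flat index i*2+j sends (i=0, j=0) to process 0,
        # (i=0, j=1) to process 1, (i=1, j=0) to process 0, and so on.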

        # Skip this sample if its first reference image is missing on disk.
        first_ref_path = os.path.join(data_root, data_dict["image_paths"][0])
        if not os.path.exists(first_ref_path):
            print(f"Image {i} does not exist, skipping...")
            print("path:", first_ref_path)
            continue

        # Resolve the original image (if any) and a basename for the outputs.
        ori_img = None
        ori_img_path = data_dict.get("ori_img_path", None)
        if ori_img_path is None:
            basename = f"{i}_{j}"
        else:
            basename = ori_img_path.split(".")[0].replace("/", "_")
            ori_img = Image.open(os.path.join(data_root, ori_img_path))

        bboxes = None
        print("Processing image", i, basename)

        if not args.bbox_from_json:
            # Detect face boxes directly on the original image.
            if ori_img_path is None:
                print(f"Image {i} has no ori_img_path, cannot extract bbox, skipping...")
                continue
            bboxes = face_extractor.locate_bboxes(ori_img)
            # Trim the bbox list to the number of reference images.
            if bboxes is not None and len(bboxes) > len(data_dict["image_paths"]):
                bboxes = bboxes[: len(data_dict["image_paths"])]
            elif bboxes is not None and len(bboxes) < len(data_dict["image_paths"]):
                print(f"Image {i} has fewer faces than image_paths, skipping...")
                continue
        else:
            # Read face boxes from the sidecar JSON next to the original image.
            json_file_path = os.path.join(data_root, basename + ".json")
            if os.path.exists(json_file_path):
                with open(json_file_path, "r") as f:
                    json_data = json.load(f)
                old_bboxes = json_data.get("bboxes", None)
                if old_bboxes is None:
                    print(f"Image {i} has no bboxes in json file, using backup bboxes...")
                    # v202 -> 2 faces, v200_single -> 1 face
                    if "v202" in args.eval_json_path:
                        old_bboxes = random.choice(BACK_UP_BBOXES_DOUBLE)
                    elif "v200_single" in args.eval_json_path:
                        old_bboxes = random.choice(BACK_UP_BBOXES)

                def recalculate_bbox(bbox, crop):
                    """The image was cropped, so shift the bbox into crop coordinates.

                    bbox: [x1, y1, x2, y2]
                    crop: [x1c, y1c, x2c, y2c]
                    Subtracting the crop origin (x1c, y1c) from both corners is enough.
                    """
                    x1, y1, x2, y2 = bbox
                    x1c, y1c, x2c, y2c = crop
                    return [x1 - x1c, y1 - y1c, x2 - x1c, y2 - y1c]
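                # Worked example (made-up numbers): a face at [300, 100, 400, 200]
                # inside a crop with origin (50, 50) maps to
                # [300-50, 100-50, 400-50, 200-50] = [250, 50, 350, 150].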

                if old_bboxes is not None:
                    crop = json_data.get("crop", None)
                    rec_bboxes = [
                        recalculate_bbox(bbox, crop) if crop is not None else bbox
                        for bbox in old_bboxes
                    ]
                    if ori_img_path is not None:
                        # Resize to 512px while preserving faces, mapping the
                        # boxes into the resized coordinate frame.
                        _, bboxes = general_face_preserving_resize(ori_img, rec_bboxes, 512)
                    else:
                        # Otherwise assume the provided boxes are already in
                        # target-resolution coordinates.
                        bboxes = rec_bboxes

        if bboxes is None:
            print(f"Image {i} has no face, bboxes are None, using backup bboxes..., basename: {basename}")
            bboxes = random.choice(BACK_UP_BBOXES_DOUBLE)
            print(f"Use backup bboxes: {bboxes}")

        ref_imgs = []
        arcface_embeddings = []
        if not args.use_rec:
            # Extract one aligned face crop + ArcFace embedding per reference image.
            break_flag = False
            for img_path in data_dict["image_paths"]:
                img = Image.open(os.path.join(data_root, img_path))
                ref_img, arcface_embedding = face_extractor.extract(img)
                if ref_img is not None and arcface_embedding is not None:
                    if args.use_matting:
                        # Matte out the background so only the subject remains.
                        ref_img, _ = extract_object(birefnet, ref_img)
                    ref_imgs.append(ref_img)
                    arcface_embeddings.append(arcface_embedding)
                else:
                    print(f"Image {i} has no face, skipping...")
                    break_flag = True
                    break
            if break_flag:
                continue
        else:
            # Reconstruction mode: take the references from the original image itself.
            if ori_img is None:
                print(f"Image {i} has no ori_img_path, cannot use use_rec, skipping...")
                continue
            ref_imgs, arcface_embeddings = face_extractor.extract_refs(ori_img)
            if ref_imgs is None or arcface_embeddings is None:
                print(f"Image {i} has no face, skipping...")
                continue
            if args.use_matting:
                ref_imgs = [extract_object(birefnet, ref_img)[0] for ref_img in ref_imgs]

        # Stack the per-face ArcFace embeddings into one tensor on the device.
        arcface_embeddings = [torch.tensor(e) for e in arcface_embeddings]
        arcface_embeddings = torch.stack(arcface_embeddings).to(accelerator.device)
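        # ArcFace embeddings are typically 512-dimensional (an assumption about
        # FaceExtractor here), so the stacked tensor has shape
        # (num_faces, 512) -- one row per reference identity.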

        # If any reference crop failed to extract, skip this sample.
        if any(ref_img is None for ref_img in ref_imgs):
            print(f"Image {i}: failed to extract face, skipping...")
            continue

        if args.ref_size == -1:
            # Default: one reference gets a 512px crop, several get 320px each.
            args.ref_size = 512 if len(ref_imgs) == 1 else 320

        if args.trigger:
            # Prepend the LoRA trigger word to the prompt.
            data_dict["prompt"] = args.trigger + " " + data_dict["prompt"]

        image_gen = pipeline(
            prompt=data_dict["prompt"] if not args.drop_text else "",
            width=args.width,
            height=args.height,
            guidance=args.guidance,
            num_steps=args.num_steps,
            seed=args.seed,
            ref_imgs=ref_imgs,
            arcface_embeddings=arcface_embeddings,
            bboxes=[bboxes],  # batched: one bbox list per generated image
            id_weight=args.id_weight,
            siglip_weight=args.siglip_weight,
        )

        if args.concat_refs:
            image_gen = horizontal_concat([image_gen, *ref_imgs])

        # Save the references, the generated image, and the original image.
        save_path = os.path.join(args.save_path, basename)
        os.makedirs(save_path, exist_ok=True)
        for k, ref_img in enumerate(ref_imgs):
            ref_img.save(os.path.join(save_path, f"ref_{k}.jpg"))
        image_gen.save(os.path.join(save_path, "out.jpg"))
        if ori_img is not None:
            ori_img.save(os.path.join(save_path, "ori.jpg"))

        # Save the config. Copy the dataclass into a fresh dict so the
        # per-sample prompt/name never mutate the shared args object.
        args_dict = dataclasses.asdict(args)
        args_dict["prompt"] = data_dict["prompt"]
        args_dict["name"] = data_dict.get("name", None)
        with open(os.path.join(save_path, "meta.json"), "w") as f:
            json.dump(args_dict, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    parser = HfArgumentParser([InferenceArgs])
    args = parser.parse_args_into_dataclasses()[0]
    main(args)
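
# Example invocation (hypothetical file name and paths; HfArgumentParser
# exposes every InferenceArgs field as a command-line flag of the same name):
#
#   accelerate launch inference.py \
#       --eval_json_path data/eval_v202.json \
#       --data_root ./data \
#       --save_path output/inference \
#       --num_images_per_prompt 2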