# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import itertools
import json
import os
from typing import Literal

from dotenv import load_dotenv

# Load .env before the model libraries are imported so local weight paths
# are visible to everything below.
load_dotenv()

import torch
from accelerate import Accelerator
from PIL import Image
from tqdm import tqdm
from transformers import HfArgumentParser, SiglipImageProcessor, SiglipVisionModel

from uso.flux.pipeline import USOPipeline, preprocess_ref


def horizontal_concat(images):
    """Paste the given images side by side on a single RGB canvas."""
    widths, heights = zip(*(img.size for img in images))
    total_width = sum(widths)
    max_height = max(heights)

    new_im = Image.new("RGB", (total_width, max_height))

    x_offset = 0
    for img in images:
        new_im.paste(img, (x_offset, 0))
        x_offset += img.size[0]

    return new_im


@dataclasses.dataclass
class InferenceArgs:
    prompt: str | None = None
    image_paths: list[str] | None = None
    eval_json_path: str | None = None
    offload: bool = False
    num_images_per_prompt: int = 1
    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
    width: int = 1024
    height: int = 1024
    num_steps: int = 25
    guidance: float = 4
    seed: int = 3407
    save_path: str = "output/inference"
    only_lora: bool = True
    concat_refs: bool = False
    lora_rank: int = 128
    pe: Literal["d", "h", "w", "o"] = "d"
    content_ref: int = 512
    ckpt_path: str | None = None
    use_siglip: bool = True
    instruct_edit: bool = False
    hf_download: bool = False  # keep False: weights must come from local paths, never auto-downloaded


def main(args: InferenceArgs):
    accelerator = Accelerator()

    # Init the SigLIP vision encoder. Weights load from the local path given
    # by SIGLIP_PATH in .env instead of being downloaded from the Hub.
    siglip_processor = None
    siglip_model = None
    if args.use_siglip:
        siglip_path = os.getenv("SIGLIP_PATH", "google/siglip-so400m-patch14-384")
        siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
        siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
        siglip_model.eval()
        siglip_model.to(accelerator.device)
        print("SigLIP model loaded successfully")

    pipeline = USOPipeline(
        args.model_type,
        accelerator.device,
        args.offload,
        only_lora=args.only_lora,
        lora_rank=args.lora_rank,
        hf_download=args.hf_download,
    )
    if args.use_siglip and siglip_model is not None:
        pipeline.model.vision_encoder = siglip_model

    assert (
        args.prompt is not None or args.eval_json_path is not None
    ), "Please provide either prompt or eval_json_path"

    if args.eval_json_path is not None:
        with open(args.eval_json_path, "rt") as f:
            data_dicts = json.load(f)
        data_root = os.path.dirname(args.eval_json_path)
    else:
        data_root = ""
        data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}]

    # Shard the work round-robin: each process takes every num_processes-th item.
    print(
        f"process: {accelerator.num_processes}/{accelerator.process_index}, "
        f"process images: {len(data_dicts)}/{len(data_dicts[accelerator.process_index::accelerator.num_processes])}"
    )
    data_dicts = data_dicts[accelerator.process_index :: accelerator.num_processes]
    accelerator.wait_for_everyone()

    local_task_count = len(data_dicts) * args.num_images_per_prompt
    if accelerator.is_main_process:
        progress_bar = tqdm(total=local_task_count, desc="Generating Images")
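    # Main generation loop: one image per (data item, repeat index) pair.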
    for (i, data_dict), j in itertools.product(
        enumerate(data_dicts), range(args.num_images_per_prompt)
    ):
        ref_imgs = []
        for img_path in data_dict["image_paths"]:
            if img_path != "":
                img = Image.open(os.path.join(data_root, img_path)).convert("RGB")
                ref_imgs.append(img)
            else:
                ref_imgs.append(None)

        # Encode every style reference (all images after the first) with SigLIP.
        siglip_inputs = None
        if args.use_siglip and siglip_processor is not None:
            with torch.no_grad():
                siglip_inputs = [
                    siglip_processor(img, return_tensors="pt").to(pipeline.device)
                    for img in ref_imgs[1:]
                    if isinstance(img, Image.Image)
                ]

        # The first image is the content reference; resize it to content_ref.
        ref_imgs_pil = [
            preprocess_ref(img, args.content_ref)
            for img in ref_imgs[:1]
            if isinstance(img, Image.Image)
        ]

        if args.instruct_edit:
            # For instruction editing, generate at the content image's aspect
            # ratio, scaled back up from content_ref to the 1024 base size.
            # Cast to int so width/height stay integral for the pipeline.
            args.width, args.height = ref_imgs_pil[0].size
            args.width = int(args.width * (1024 / args.content_ref))
            args.height = int(args.height * (1024 / args.content_ref))

        image_gen = pipeline(
            prompt=data_dict["prompt"],
            width=args.width,
            height=args.height,
            guidance=args.guidance,
            num_steps=args.num_steps,
            seed=args.seed + j,
            ref_imgs=ref_imgs_pil,
            pe=args.pe,
            siglip_inputs=siglip_inputs,
        )
        if args.concat_refs:
            # Skip None placeholders left by empty image paths before pasting.
            image_gen = horizontal_concat(
                [image_gen, *[img for img in ref_imgs if img is not None]]
            )

        if "save_dir" in data_dict:
            config_save_path = os.path.join(args.save_path, data_dict["save_dir"] + f"_{j}.json")
            image_save_path = os.path.join(args.save_path, data_dict["save_dir"] + f"_{j}.png")
        else:
            os.makedirs(args.save_path, exist_ok=True)
            config_save_path = os.path.join(args.save_path, f"{i}_{j}.json")
            image_save_path = os.path.join(args.save_path, f"{i}_{j}.png")

        # Save the image and a config snapshot alongside it.
        os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
        image_gen.save(image_save_path)

        # Ensure the prompt and image_paths are recorded in the config file.
        args.prompt = data_dict["prompt"]
        args.image_paths = data_dict["image_paths"]
        args_dict = vars(args)
        with open(config_save_path, "w") as f:
            json.dump(args_dict, f, indent=4)

        if accelerator.is_main_process:
            progress_bar.update(1)

    if accelerator.is_main_process:
        progress_bar.close()


if __name__ == "__main__":
    parser = HfArgumentParser([InferenceArgs])
    args = parser.parse_args_into_dataclasses()[0]
    main(args)
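
# Example invocation (a sketch; the script name, .env contents, and asset
# paths below are assumptions, not fixed by this file):
#
#   # .env
#   SIGLIP_PATH=/local/weights/siglip-so400m-patch14-384
#
#   accelerate launch inference.py \
#       --prompt "A fluffy cat sitting on a windowsill" \
#       --image_paths assets/content_ref.png assets/style_ref.png \
#       --save_path output/inference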