|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.utils.checkpoint |
|
|
from accelerate import Accelerator |
|
|
from diffusers import DDIMScheduler |
|
|
from diffusers.utils import check_min_version |
|
|
from safetensors.torch import load_file |
|
|
from tqdm import tqdm |
|
|
from transformers import AutoTokenizer |
|
|
from utils.args_loader import parse_args |
|
|
from utils.dataset import make_dataset |
|
|
from utils.light_controlnet import ControlNetModel |
|
|
from utils.pipeline_controlnet import LightControlNetPipeline |
|
|
from utils.unet_2d_condition import UNet2DConditionNewModel |
|
|
|
|
|
|
|
|
sys.path.append("../../src") |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
|
|
|
check_min_version("0.10.0.dev0") |
|
|
if torch.xpu.is_available(): |
|
|
device = "xpu:0" |
|
|
elif torch.cuda.is_available(): |
|
|
device = "cuda:0" |
|
|
else: |
|
|
device = "cpu" |
|
|
|
|
|
|
|
|
def main(args): |
|
|
logging_dir = Path(args.output_dir, args.logging_dir) |
|
|
|
|
|
accelerator = Accelerator( |
|
|
gradient_accumulation_steps=args.gradient_accumulation_steps, |
|
|
mixed_precision=args.mixed_precision, |
|
|
log_with=args.report_to, |
|
|
project_dir=logging_dir, |
|
|
) |
|
|
|
|
|
|
|
|
if args.tokenizer_name: |
|
|
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) |
|
|
elif args.pretrained_model_name_or_path: |
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
args.pretrained_model_name_or_path, |
|
|
subfolder="tokenizer", |
|
|
revision=args.revision, |
|
|
use_fast=False, |
|
|
) |
|
|
|
|
|
val_dataset = make_dataset(args, tokenizer, accelerator, "test") |
|
|
|
|
|
controlnet_path = args.controlnet_path |
|
|
unet_path = args.unet_path |
|
|
|
|
|
controlnet = ControlNetModel() |
|
|
controlnet.load_state_dict(load_file(controlnet_path)) |
|
|
unet = UNet2DConditionNewModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") |
|
|
unet = PeftModel.from_pretrained(unet, unet_path, adapter_name=args.adapter_name) |
|
|
|
|
|
pipe = LightControlNetPipeline.from_pretrained( |
|
|
args.pretrained_model_name_or_path, |
|
|
controlnet=controlnet, |
|
|
unet=unet.model, |
|
|
torch_dtype=torch.float32, |
|
|
requires_safety_checker=False, |
|
|
).to(device) |
|
|
|
|
|
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) |
|
|
|
|
|
if not os.path.exists(args.output_dir): |
|
|
os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
|
|
exist_lst = [int(img.split("_")[-1][:-4]) for img in os.listdir(args.output_dir)] |
|
|
all_lst = np.arange(len(val_dataset)) |
|
|
idx_lst = [item for item in all_lst if item not in exist_lst] |
|
|
|
|
|
print("Number of images to be processed: ", len(idx_lst)) |
|
|
|
|
|
np.random.seed(seed=int(time.time())) |
|
|
np.random.shuffle(idx_lst) |
|
|
|
|
|
for idx in tqdm(idx_lst): |
|
|
output_path = os.path.join(args.output_dir, f"pred_img_{idx:04d}.png") |
|
|
|
|
|
if not os.path.exists(output_path): |
|
|
data = val_dataset[idx.item()] |
|
|
negative_prompt = "low quality, blurry, unfinished" |
|
|
|
|
|
with torch.no_grad(): |
|
|
pred_img = pipe( |
|
|
data["text"], |
|
|
[data["conditioning_pixel_values"]], |
|
|
num_inference_steps=50, |
|
|
guidance_scale=7, |
|
|
negative_prompt=negative_prompt, |
|
|
).images[0] |
|
|
|
|
|
pred_img.save(output_path) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = parse_args() |
|
|
main(args) |
|
|
|