# DreamOmni2-Gen / inference_edit.py
import torch
try:
    # Ascend NPU support: import torch_npu and patch transformers, flash-attn,
    # and diffusers so the pipeline can run on NPU instead of CUDA.
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    import importlib
    import transformers.utils
    import transformers.models
    origin_utils = transformers.utils
    origin_models = transformers.models
    import flash_attn
    flash_attn.hack_transformers_flash_attn_2_available_check()
    importlib.reload(transformers.utils)
    importlib.reload(transformers.models)

    # Wrap torch.nn.functional.interpolate so bilinear calls run in bfloat16
    # on NPU and the result is cast back to the original dtype.
    origin_func = torch.nn.functional.interpolate

    def new_func(input, size=None, scale_factor=None, mode='nearest', align_corners=None,
                 recompute_scale_factor=None, antialias=False):
        if mode == "bilinear":
            dtype = input.dtype
            res = origin_func(input.to(torch.bfloat16), size, scale_factor, mode,
                              align_corners, recompute_scale_factor, antialias)
            return res.to(dtype)
        else:
            return origin_func(input, size, scale_factor, mode,
                               align_corners, recompute_scale_factor, antialias)

    torch.nn.functional.interpolate = new_func

    from utils import patch_npu_record_stream
    from utils import patch_npu_diffusers_get_1d_rotary_pos_embed
    patch_npu_record_stream()
    patch_npu_diffusers_get_1d_rotary_pos_embed()
    USE_NPU = True
except Exception:
    # torch_npu (or one of the NPU patches) is unavailable; fall back to CUDA.
    USE_NPU = False
from dreamomni2.pipeline_dreamomni2 import DreamOmni2Pipeline
from diffusers.utils import load_image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
# from qwen_vl_utils import process_vision_info
from utils.vprocess import process_vision_info, resizeinput
import os
import argparse
from tqdm import tqdm
import json
from PIL import Image
import re
if USE_NPU:
    device = "npu"
else:
    device = "cuda"
def extract_gen_content(text):
    # Strip the fixed-length opening tag (6 chars) and closing tag (7 chars)
    # that wrap the editing prompt generated by the VLM.
    text = text[6:-7]
    return text
def parse_args():
    """Parses command-line arguments for model paths and inference configuration."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vlm_path",
        type=str,
        default="./models/vlm-model",
        help="Path to the VLM model directory."
    )
    parser.add_argument(
        "--edit_lora_path",
        type=str,
        default="./models/edit_lora",
        help="Path to the FLUX.1-Kontext editing LoRA weights directory."
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="black-forest-labs/FLUX.1-Kontext-dev",
        help="Path to the FLUX.1-Kontext base model."
    )
    parser.add_argument(
        "--input_img_path",
        type=str,
        nargs='+',  # Accept one or more input paths
        default=["example_input/edit_tests/src.jpg", "example_input/edit_tests/ref.jpg"],
        help="List of input image paths (e.g., src and ref images)."
    )
    # Argument for the input instruction
    parser.add_argument(
        "--input_instruction",
        type=str,
        default="Make the woman from the second image stand on the road in the first image.",
        help="Instruction for image editing."
    )
    # Argument for the output image path
    parser.add_argument(
        "--output_path",
        type=str,
        default="example_input/edit_tests/edi_res.png",
        help="Path to save the output image."
    )
    args = parser.parse_args()
    return args
ARGS = parse_args()
vlm_path = ARGS.vlm_path
edit_lora_path = ARGS.edit_lora_path
base_model = ARGS.base_model_path
pipe = DreamOmni2Pipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pipe.to(device)
pipe.load_lora_weights(
    edit_lora_path,
    adapter_name="edit"
)
pipe.set_adapters(["edit"], adapter_weights=[1])

# Load the VLM that rewrites the user instruction into an editing prompt.
# `device` resolves to "npu" or "cuda" depending on availability.
vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    vlm_path, torch_dtype=torch.bfloat16, device_map=device
)
processor = AutoProcessor.from_pretrained(vlm_path)
def infer_vlm(input_img_path, input_instruction, prefix):
    # Build a single-turn chat message: all input images followed by the
    # editing instruction plus the task prefix.
    tp = []
    for path in input_img_path:
        tp.append({"type": "image", "image": path})
    tp.append({"type": "text", "text": input_instruction + prefix})
    messages = [
        {
            "role": "user",
            "content": tp,
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    # Inference
    generated_ids = vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
def infer(source_imgs, prompt):
    # Run the FLUX.1-Kontext editing pipeline, keeping the output at the
    # resolution of the first source image.
    image = pipe(
        images=source_imgs,
        height=source_imgs[0].height,
        width=source_imgs[0].width,
        prompt=prompt,
        num_inference_steps=30,
        guidance_scale=3.5,
    ).images[0]
    return image
input_img_path = ARGS.input_img_path
input_instruction = ARGS.input_instruction
prefix = " It is editing task."

# Load and resize the source/reference images.
source_imgs = []
for path in input_img_path:
    img = load_image(path)
    source_imgs.append(resizeinput(img))

# Rewrite the instruction with the VLM, extract the prompt, then generate and save.
prompt = infer_vlm(input_img_path, input_instruction, prefix)
prompt = extract_gen_content(prompt)
image = infer(source_imgs, prompt)
output_path = ARGS.output_path
image.save(output_path)
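
# Example invocation (a sketch only; assumes this file is saved as inference_edit.py
# and that the default model/LoRA paths above exist locally or on the Hugging Face Hub):
#
#   python inference_edit.py \
#       --vlm_path ./models/vlm-model \
#       --edit_lora_path ./models/edit_lora \
#       --base_model_path black-forest-labs/FLUX.1-Kontext-dev \
#       --input_img_path example_input/edit_tests/src.jpg example_input/edit_tests/ref.jpg \
#       --input_instruction "Make the woman from the second image stand on the road in the first image." \
#       --output_path example_input/edit_tests/edi_res.png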