import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tqdm import tqdm
import gc
import math
import os
import base64
import json
from optimization import optimize_pipeline_
from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from lora_manager import LoRAManager
# System prompt for prompt enhancement
SYSTEM_PROMPT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.
Please strictly follow the rewriting rules below:
## 1. General Principles
- Keep the rewritten prompt **concise and comprehensive**. Avoid overly long sentences and unnecessary descriptive language.
- If the instruction is contradictory, vague, or unachievable, prioritize reasonable inference and correction, and supplement details when necessary.
- Keep the main part of the original instruction unchanged, only enhancing its clarity, rationality, and visual feasibility.
- All added objects or modifications must align with the logic and style of the scene in the input images.
- If multiple sub-images are to be generated, describe the content of each sub-image individually.
## 2. Task-Type Handling Rules
### 1. Add, Delete, Replace Tasks
- If the instruction is clear (already includes task type, target entity, position, quantity, attributes), preserve the original intent and only refine the grammar.
- If the description is vague, supplement with minimal but sufficient details (category, color, size, orientation, position, etc.). For example:
> Original: "Add an animal"
> Rewritten: "Add a light-gray cat in the bottom-right corner, sitting and facing the camera"
- Remove meaningless instructions: e.g., "Add 0 objects" should be ignored or flagged as invalid.
- For replacement tasks, specify "Replace Y with X" and briefly describe the key visual features of X.
### 2. Text Editing Tasks
- All text content must be enclosed in English double quotes `" "`. Keep the original language of the text, and keep the capitalization.
- Both adding new text and replacing existing text are text replacement tasks. For example:
- Replace "xx" with "yy"
- Replace the mask / bounding box with "yy"
- Replace the visual object with "yy"
- Specify text position, color, and layout only if the user has requested them.
- If font is specified, keep the original language of the font.
### 3. Human Editing Tasks
- Make the smallest changes to the given user's prompt.
- If changes to background, action, expression, camera shot, or ambient lighting are required, please list each modification individually.
- **Edits to makeup or facial features / expression must be subtle, not exaggerated, and must preserve the subject's identity consistency.**
> Original: "Add eyebrows to the face"
> Rewritten: "Slightly thicken the person's eyebrows with little change, look natural."
### 4. Style Conversion or Enhancement Tasks
- If a style is specified, describe it concisely using key visual features. For example:
> Original: "Disco style"
> Rewritten: "1970s disco style: flashing lights, disco ball, mirrored walls, vibrant colors"
- For style reference, analyze the original image and extract key characteristics (color, composition, texture, lighting, artistic style, etc.), integrating them into the instruction.
- **Colorization tasks (including old photo restoration) must use the fixed template:**
"Restore and colorize the old photo."
- Clearly specify the object to be modified. For example:
> Original: Modify the subject in Picture 1 to match the style of Picture 2.
> Rewritten: "Change the girl in Picture 1 to the ink-wash style of Picture 2 — rendered in black-and-white watercolor with soft color transitions.
### 5. Material Replacement
- Clearly specify the object and the material. For example: "Change the material of the apple to papercut style."
- For text material replacement, use the fixed template:
"Change the material of text "xxxx" to laser style"
### 6. Logo/Pattern Editing
- Material replacement should preserve the original shape and structure as much as possible. For example:
> Original: "Convert to sapphire material"
> Rewritten: "Convert the main subject in the image to sapphire material, preserving similar shape and structure"
- When migrating logos/patterns to new scenes, ensure shape and structure consistency. For example:
> Original: "Migrate the logo in the image to a new scene"
> Rewritten: "Migrate the logo in the image to a new scene, preserving similar shape and structure"
### 7. Multi-Image Tasks
- Rewritten prompts must clearly point out which image's element is being modified. For example:
> Original: "Replace the subject of picture 1 with the subject of picture 2"
> Rewritten: "Replace the girl of picture 1 with the boy of picture 2, keeping picture 2's background unchanged"
- For stylization tasks, describe the reference image's style in the rewritten prompt, while preserving the visual content of the source image.
## 3. Rationale and Logic Check
- Resolve contradictory instructions: e.g., "Remove all trees but keep all trees" requires logical correction.
- Supplement missing critical information: e.g., if position is unspecified, choose a reasonable area based on composition (near subject, blank space, center/edge, etc.).
# Output Format Example
```json
{
"Rewritten": "..."
}
```
'''
def encode_image(pil_image):
"""Encode PIL image to base64 string for API calls"""
import io
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
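# The base64 payload is embedded as a data: URL when images are attached to the
# chat-completions request in polish_prompt_hf below.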
def polish_prompt_hf(prompt, img_list):
"""Rewrite prompt using Hugging Face InferenceClient"""
from huggingface_hub import InferenceClient
# Ensure HF_TOKEN is set
api_key = os.environ.get("HF_TOKEN")
if not api_key:
print("Warning: HF_TOKEN not set. Falling back to original prompt.")
return prompt
try:
# Format the prompt for the API
formatted_prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
# Initialize the client
client = InferenceClient(
provider="novita",
api_key=api_key,
)
# Format the messages for the chat completions API
sys_prompt = "you are a helpful assistant, you should provide useful answers to users."
# Create messages structure
messages = [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": []}
]
# Add images to the message
for img in img_list:
messages[1]["content"].append(
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encode_image(img)}"}})
# Add text to the message
messages[1]["content"].append({"type": "text", "text": f"{formatted_prompt}"})
completion = client.chat.completions.create(
model="Qwen/Qwen3-Next-80B-A3B-Instruct",
messages=messages,
)
# Parse the response
result = completion.choices[0].message.content
# Try to extract JSON if present
if '{"Rewritten"' in result:
try:
# Clean up the response
result = result.replace('```json', '').replace('```', '')
result_json = json.loads(result)
polished_prompt = result_json.get('Rewritten', result)
except json.JSONDecodeError:
polished_prompt = result
else:
polished_prompt = result
polished_prompt = polished_prompt.strip().replace("\n", " ")
return polished_prompt
except Exception as e:
print(f"Error during API call to Hugging Face: {e}")
# Fallback to original prompt if enhancement fails
return prompt
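# Example usage (assumes HF_TOKEN is set and `img` is a PIL.Image loaded elsewhere):
#   enhanced = polish_prompt_hf("make the sky more dramatic", [img])
# On any failure the function falls back to the original prompt.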
# Define simplified LoRA configurations with Lightning as always-loaded base
LORA_CONFIG = {
"Lightning (4-Step)": {
"repo_id": "lightx2v/Qwen-Image-Lightning",
"filename": "Qwen-Image-Lightning-4steps-V2.0.safetensors",
"type": "base",
"method": "standard",
"always_load": True,
"prompt_template": "{prompt}",
"description": "Fast 4-step generation LoRA - always loaded as base optimization.",
},
"None": {
"repo_id": None,
"filename": None,
"type": "edit",
"method": "none",
"prompt_template": "{prompt}",
"description": "Use the base Qwen-Image-Edit model with Lightning optimization.",
},
"Object Remover": {
"repo_id": "valiantcat/Qwen-Image-Edit-Remover-General-LoRA",
"filename": "qwen-edit-remover.safetensors",
"type": "edit",
"method": "standard",
"prompt_template": "Remove {prompt}",
"description": "Removes objects from an image while maintaining background consistency.",
},
}
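# Additional edit LoRAs can be exposed by adding entries with the same schema,
# e.g. (hypothetical repo_id/filename, shown for illustration only):
# "Sketch Style": {
#     "repo_id": "some-user/qwen-edit-sketch-lora",
#     "filename": "sketch.safetensors",
#     "type": "edit",
#     "method": "standard",
#     "prompt_template": "{prompt}",
#     "description": "Example entry; replace with a real repo and filename.",
# },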
# --- Model and scheduler initialization ---
print("Initializing model...")
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Scheduler configuration for Lightning
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3),
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3),
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None,
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
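# Note: base_shift == max_shift == log(3), so the dynamic exponential time shift
# resolves to an effectively constant shift of 3 across image sizes, which matches
# the few-step Lightning setup.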
# Initialize scheduler with Lightning config
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# --- UI Constants and Helpers ---
MAX_SEED = np.iinfo(np.int32).max
# Load the model pipeline
pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509",
scheduler=scheduler,
torch_dtype=dtype).to(device)
# Apply model optimizations
pipe.transformer.__class__ = QwenImageTransformer2DModel
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
# Initialize LoRA Manager
lora_manager = LoRAManager(pipe, device)
# Always load Lightning LoRA first
LIGHTNING_LORA_NAME = "Lightning (4-Step)"
print(f"Loading always-active Lightning LoRA: {LIGHTNING_LORA_NAME}")
# Load and register Lightning LoRA
lightning_config = LORA_CONFIG[LIGHTNING_LORA_NAME]
lightning_lora_path = hf_hub_download(
repo_id=lightning_config["repo_id"],
filename=lightning_config["filename"]
)
lora_manager.register_lora(LIGHTNING_LORA_NAME, lightning_lora_path, **lightning_config)
lora_manager.configure_lora(LIGHTNING_LORA_NAME, {
"description": lightning_config["description"],
"is_base": True
})
# Load Lightning LoRA and keep it always active
lora_manager.load_lora(LIGHTNING_LORA_NAME)
lora_manager.fuse_lora(LIGHTNING_LORA_NAME)
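# Fusing merges the Lightning weights into the transformer itself, so Lightning
# should remain in effect even when additional adapters are toggled later.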
# Register other LoRAs (only Object Remover for testing)
for lora_name, config in LORA_CONFIG.items():
if lora_name != LIGHTNING_LORA_NAME and config["repo_id"] is not None:
lora_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])
lora_manager.register_lora(lora_name, lora_path, **config)
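# Snapshot of the transformer weights taken after Lightning fusion. state_dict()
# returns references to the live tensors, so clone the values if a true restore
# point (e.g. to undo a manual fusion) is ever needed.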
original_transformer_state_dict = pipe.transformer.state_dict()
print("Base model and Lightning LoRA loaded and ready.")
def fuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
"""Manual LoRA fusion method"""
key_mapping = {}
for key in lora_state_dict.keys():
base_key = key.replace('diffusion_model.', '').rsplit('.lora_', 1)[0]
if base_key not in key_mapping:
key_mapping[base_key] = {}
if 'lora_A' in key:
key_mapping[base_key]['down'] = lora_state_dict[key]
elif 'lora_B' in key:
key_mapping[base_key]['up'] = lora_state_dict[key]
for name, module in tqdm(transformer.named_modules(), desc="Fusing layers"):
if name in key_mapping and isinstance(module, torch.nn.Linear):
lora_weights = key_mapping[name]
if 'down' in lora_weights and 'up' in lora_weights:
device = module.weight.device
dtype = module.weight.dtype
lora_down = lora_weights['down'].to(device, dtype=dtype)
lora_up = lora_weights['up'].to(device, dtype=dtype)
merged_delta = lora_up @ lora_down
module.weight.data += alpha * merged_delta
return transformer
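# Example (hypothetical file path, for illustration only):
#   sd = load_file("/path/to/extra_lora.safetensors")
#   pipe.transformer = fuse_lora_manual(pipe.transformer, sd, alpha=1.0)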
def load_and_fuse_additional_lora(lora_name):
"""
Load an additional LoRA while keeping Lightning LoRA always active.
This enables combining Lightning's speed with other LoRA capabilities.
"""
config = LORA_CONFIG[lora_name]
print(f"Loading additional LoRA: {lora_name} (Lightning will remain active)")
# Get LoRA path from registry
if lora_name in lora_manager.lora_registry:
lora_path = lora_manager.lora_registry[lora_name]["lora_path"]
else:
print(f"LoRA {lora_name} not found in registry")
return
# Always keep Lightning LoRA loaded
# Load additional LoRA without resetting to base state
if config["method"] == "standard":
print("Using standard loading method...")
# Load additional LoRA without fusing (to preserve Lightning)
pipe.load_lora_weights(lora_path, adapter_name=lora_name)
# Set both adapters as active
pipe.set_adapters([LIGHTNING_LORA_NAME, lora_name])
print(f"Lightning + {lora_name} now active.")
elif config["method"] == "manual_fuse":
print("Using manual fusion method...")
lora_state_dict = load_file(lora_path)
# Manual fusion on top of Lightning
pipe.transformer = fuse_lora_manual(pipe.transformer, lora_state_dict)
print(f"Lightning + {lora_name} manually fused.")
gc.collect()
torch.cuda.empty_cache()
def load_and_fuse_lora(lora_name):
"""Legacy function for backward compatibility"""
if lora_name == LIGHTNING_LORA_NAME:
# Lightning is already loaded, just ensure it's active
print("Lightning LoRA is already active.")
pipe.set_adapters([LIGHTNING_LORA_NAME])
return
load_and_fuse_additional_lora(lora_name)
# Ahead-of-time compilation with minimal memory footprint
# Use tiny images to minimize memory during compilation
optimize_pipeline_(pipe, image=[Image.new("RGB", (64, 64)), Image.new("RGB", (64, 64))], prompt="test")
print("Model compilation complete.")
@spaces.GPU(duration=60)
def infer(
lora_name,
input_image,
style_image,
prompt,
seed,
randomize_seed,
true_guidance_scale,
num_inference_steps,
progress=gr.Progress(track_tqdm=True),
):
"""Main inference function with Lightning always active"""
if not lora_name:
raise gr.Error("Please select a LoRA model.")
config = LORA_CONFIG[lora_name]
if config["type"] == "style":
if style_image is None:
raise gr.Error("Style Transfer LoRA requires a Style Reference Image.")
image_for_pipeline = style_image
else: # 'edit' or 'base'
if input_image is None:
raise gr.Error("This LoRA requires an Input Image.")
image_for_pipeline = input_image
if not prompt and config["prompt_template"] != "change the face to face segmentation mask":
raise gr.Error("A text prompt is required for this LoRA.")
# Load additional LoRA while keeping Lightning active
load_and_fuse_lora(lora_name)
final_prompt = config["prompt_template"].format(prompt=prompt)
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(int(seed))
print("--- Running Inference ---")
print(f"LoRA: {lora_name} (with Lightning always active)")
print(f"Prompt: {final_prompt}")
print(f"Seed: {seed}, Steps: {num_inference_steps}, CFG: {true_guidance_scale}")
with torch.inference_mode():
result_image = pipe(
image=image_for_pipeline,
prompt=final_prompt,
negative_prompt=" ",
num_inference_steps=int(num_inference_steps),
generator=generator,
true_cfg_scale=true_guidance_scale,
).images[0]
# Don't unfuse Lightning - keep it active for next inference
if lora_name != LIGHTNING_LORA_NAME:
pipe.disable_adapters()  # Turn off the additional adapter; Lightning stays active because it was fused into the weights
gc.collect()
torch.cuda.empty_cache()
return result_image, seed
def on_lora_change(lora_name):
"""Dynamic UI component visibility handler"""
config = LORA_CONFIG[lora_name]
is_style_lora = config["type"] == "style"
# Lightning LoRA info
lightning_info = "⚡ **Lightning LoRA always active** - Fast 4-step generation enabled"
return {
lora_description: gr.Markdown(visible=True, value=f"{lightning_info}\n\n**Description:** {config['description']}"),
input_image_box: gr.Image(visible=not is_style_lora, type="pil"),
style_image_box: gr.Image(visible=is_style_lora, type="pil"),
prompt_box: gr.Textbox(visible=(config["prompt_template"] != "change the face to face segmentation mask"))
}
with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 1024px; }") as demo:
with gr.Column(elem_id="col-container"):
gr.HTML('
')
gr.Markdown("