import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPipeline
from diffusers.utils import is_xformers_available
import os
import re
import gc
import json  # required by extract_json_response below
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
#############################
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-7B-Chat" # Upgraded to 7B for better JSON handling
rewriter_tokenizer = None
rewriter_model = None
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
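# NF4 4-bit quantization (with double quantization) shrinks the 7B rewriter's
# VRAM footprint so it can share the GPU with the diffusion pipeline.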
def load_rewriter():
"""Lazily load the prompt enhancement model"""
global rewriter_tokenizer, rewriter_model
if rewriter_tokenizer is None or rewriter_model is None:
print("🔄 Loading enhancement model...")
rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
rewriter_model = AutoModelForCausalLM.from_pretrained(
REWRITER_MODEL,
torch_dtype=dtype,
device_map="auto",
quantization_config=bnb_config
)
print("✅ Enhancement model loaded")
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure newly added elements or modifications align with the image's original context and art style.
## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.
### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserve the original structure and language; **do not translate** or alter the style.
### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.
### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- For **coloring/restoration**, use the fixed template below where applicable:
  `"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
## 3. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
"Rewritten": "..."
}
'''
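# Chat models often wrap their JSON in prose or code fences, so the parser
# below isolates the outermost {...} and repairs common key-quoting slips.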
def extract_json_response(model_output: str) -> str | None:
    """Extract the rewritten instruction from potentially messy JSON output"""
    try:
        # Isolate the JSON portion of the output
        start_idx = model_output.find('{')
        end_idx = model_output.rfind('}') + 1
        if start_idx == -1 or end_idx == 0:
            return None
        json_str = model_output[start_idx:end_idx]
        # Repair common key-quoting slips; values are left untouched, since
        # blindly quoting them corrupts otherwise valid JSON
        json_str = re.sub(r'(?<!")\b(\w+)\b(?=":)', r'"\1"', json_str)  # add missing opening quote
        json_str = re.sub(r'([{,]\s*)([A-Za-z_]\w*)(\s*:)', r'\1"\2"\3', json_str)  # quote bare keys
        data = json.loads(json_str)
        # Extract the rewritten prompt from possible key variations
        possible_keys = [
            "Rewritten", "rewritten", "Rewrited", "rewrited",
            "Output", "output", "Enhanced", "enhanced"
        ]
        for key in possible_keys:
            if key in data and isinstance(data[key], str):
                return data[key].strip()
        # Try the nested path some replies use
        if isinstance(data.get("Response"), dict) and "Rewritten" in data["Response"]:
            return data["Response"]["Rewritten"].strip()
        # Fallback: take any plausibly sized string value
        for value in data.values():
            if isinstance(value, str) and 10 < len(value) < 500:
                return value.strip()
    except Exception:
        pass
    return None
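# Illustrative only:
#   extract_json_response('Sure! {"Rewritten": "Replace the cat with a dog"}')
#   -> 'Replace the cat with a dog'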
def polish_prompt(original_prompt: str) -> str:
"""Enhanced prompt rewriting using original system prompt with JSON handling"""
load_rewriter()
# Format as Qwen chat
messages = [
{"role": "system", "content": SYSTEM_PROMPT_EDIT},
{"role": "user", "content": original_prompt}
]
text = rewriter_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
generated_ids = rewriter_model.generate(
**model_inputs,
            max_new_tokens=256,  # headroom for the JSON-wrapped rewrite
do_sample=True,
temperature=0.6,
top_p=0.9,
no_repeat_ngram_size=2,
pad_token_id=rewriter_tokenizer.eos_token_id
)
# Extract and clean response
enhanced = rewriter_tokenizer.decode(
generated_ids[0][model_inputs.input_ids.shape[1]:],
skip_special_tokens=True
).strip()
# Try to extract JSON content
rewritten_prompt = extract_json_response(enhanced)
if rewritten_prompt:
        # Strip quotes around edit targets and unescape any escaped quotes
        rewritten_prompt = re.sub(r'(Replace|Change|Add) "([^"]*)"', r'\1 \2', rewritten_prompt)
        rewritten_prompt = rewritten_prompt.replace('\\"', '"')
return rewritten_prompt
# Fallback cleanup if JSON extraction fails
print(f"⚠️ JSON extraction failed, using raw output: {enhanced}")
fallback = re.sub(r'```.*?```', '', enhanced, flags=re.DOTALL) # Remove code blocks
fallback = re.sub(r'[\{\}\[\]"]', '', fallback) # Remove JSON artifacts
fallback = fallback.split('\n')[0] # Take first line
    # Take the text after a leading "Key: " separator if present
    if ': ' in fallback:
        return fallback.split(': ', 1)[1].strip()
return fallback.strip()
# Load main image editing pipeline
pipe = QwenImageEditPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit",
torch_dtype=dtype
).to(device)
# Load LoRA weights for acceleration
pipe.load_lora_weights(
"lightx2v/Qwen-Image-Lightning",
weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
)
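# Fusing merges the LoRA deltas into the base weights, so the 8-step schedule
# runs without per-layer adapter overhead at inference time.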
pipe.fuse_lora()
if is_xformers_available():
pipe.enable_xformers_memory_efficient_attention()
else:
print("xformers not available")
def unload_rewriter():
"""Clear enhancement model from memory"""
global rewriter_tokenizer, rewriter_model
if rewriter_model:
del rewriter_tokenizer, rewriter_model
rewriter_tokenizer = None
rewriter_model = None
torch.cuda.empty_cache()
gc.collect()
@spaces.GPU(duration=60)
def infer(
image,
prompt,
seed=42,
randomize_seed=False,
true_guidance_scale=4.0,
num_inference_steps=8,
rewrite_prompt=False,
num_images_per_prompt=1,
):
"""Image editing endpoint with optimized prompt handling"""
original_prompt = prompt
prompt_info = ""
# Handle prompt rewriting
if rewrite_prompt:
try:
enhanced_instruction = polish_prompt(original_prompt)
prompt_info = (
f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
f"<p><strong>Original:</strong> {original_prompt}</p>"
f"<p><strong>Enhanced:</strong> {enhanced_instruction}</p>"
f"</div>"
)
prompt = enhanced_instruction
except Exception as e:
gr.Warning(f"Prompt enhancement failed: {str(e)}")
prompt_info = (
f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
f"<p>Using original prompt. Error: {str(e)}</p>"
f"</div>"
)
else:
prompt_info = (
f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
f"<p>{original_prompt}</p>"
f"</div>"
)
    # Free VRAM before diffusion (no-op if the rewriter was never loaded)
    unload_rewriter()
    # Set seed for reproducibility (kept within the seed slider's range)
    seed_val = seed
    if randomize_seed:
        seed_val = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed_val)
try:
# Generate images
edited_images = pipe(
image=image,
prompt=prompt,
negative_prompt=" ",
num_inference_steps=num_inference_steps,
generator=generator,
true_cfg_scale=true_guidance_scale,
num_images_per_prompt=num_images_per_prompt
).images
    except Exception as e:
        # gr.Error only surfaces when raised; warn instead, since we return a fallback
        gr.Warning(f"Image generation failed: {str(e)}")
prompt_info = (
f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
f"<h4 style='margin-top: 0;'><strong>⚠️ Error:</strong> {str(e)}</h4>"
f"</div>"
)
return [], seed_val, prompt_info
return edited_images, seed_val, prompt_info
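# Minimal direct-use sketch (illustrative; mirrors the pipe call above, not executed here):
#   out = pipe(image=img, prompt="Replace the cat with a dog", negative_prompt=" ",
#              num_inference_steps=8, true_cfg_scale=4.0).images[0]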
MAX_SEED = np.iinfo(np.int32).max
examples = [
"Replace the cat with a friendly golden retriever. Make it look happier, and add more background details.",
"Add text 'Qwen - AI for image editing' in Chinese at the bottom center with a small shadow.",
"Change the style to 1970s vintage, add old photo effect, restore any scratches on the wall or window.",
"Remove the blue sky and replace it with a dark night cityscape.",
"""Replace "Qwen" with "通义" in the Image. Ensure Chinese font is used and position it at top left."""
]
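# These examples feed the gr.Examples block below (commented out while WIP).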
with gr.Blocks(title="Qwen Image Editor Fast") as demo:
gr.Markdown("""
<div style="text-align: center;">
<h1>⚡️ Qwen-Image-Edit Lightning Fast 8-STEP</h1>
<p>8-step image editing with lightx2v's LoRA and local prompt enhancement</p>
<p>🚧 Work in progress, further improvements coming soon.</p>
</div>
""")
with gr.Row():
# Input Column
with gr.Column():
input_image = gr.Image(label="Input Image", type="pil")
prompt = gr.Textbox(label="Edit Instruction", placeholder="e.g. Add a dog to the right side", lines=2)
with gr.Accordion("Advanced Settings", open=False):
gr.Markdown("### Generation Parameters")
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
with gr.Row():
true_guidance_scale = gr.Slider(
label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=4.0
)
num_inference_steps = gr.Slider(
label="Inference Steps", minimum=4, maximum=16, step=1, value=8
)
num_images_per_prompt = gr.Slider(
label="Output Images", minimum=1, maximum=4, step=1, value=2
)
rewrite_toggle = gr.Checkbox(
label="Enable AI Prompt Enhancement",
value=True
)
run_button = gr.Button("Generate Edits", variant="primary")
# Output Column
with gr.Column():
            result = gr.Gallery(
                label="Output Images",
                columns=2,  # `columns` expects an int (or per-breakpoint list), not a callable
                object_fit="contain",
                height="auto"
            )
prompt_info = gr.HTML(
"<div style='margin-top:20px; padding:15px; border-radius:8px; background:#f8f9fa'>"
"<p>Prompt details will appear here after generation</p></div>"
)
# gr.Examples(
# examples=examples,
# inputs=[prompt],
# label="Try These Examples",
# cache_examples=True
# )
# Main processing
run_event = run_button.click(
fn=infer,
inputs=[
input_image,
prompt,
seed,
randomize_seed,
true_guidance_scale,
num_inference_steps,
rewrite_toggle,
num_images_per_prompt
],
outputs=[result, seed, prompt_info]
)
prompt.submit(
fn=infer,
inputs=[
input_image,
prompt,
seed,
randomize_seed,
true_guidance_scale,
num_inference_steps,
rewrite_toggle,
num_images_per_prompt
],
outputs=[result, seed, prompt_info]
)
    # Reveal the prompt info panel once a run completes
run_event.then(
fn=lambda: gr.update(visible=True),
inputs=None,
outputs=[prompt_info],
queue=False
)
if __name__ == "__main__":
demo.launch()