import gradio as gr
import numpy as np
import random
import torch
import spaces
import math  # needed by the Lightning scheduler shift values below
from PIL import Image
from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_xformers_available
import os
import sys
import re
import gc
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging
#############################
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat"  # 4B chat model handles the JSON output format more reliably
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
# Pipeline location from the QIE env var, falling back to the public checkpoint if unset
LOC = os.getenv("QIE", "Qwen/Qwen-Image-Edit")
# Quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
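# Rough footprint (estimate): NF4 packs weights into ~4 bits, so the 4B-parameter
# rewriter needs roughly 2-3 GB of VRAM, leaving headroom for the edit pipeline;
# double quantization also compresses the quantization constants themselves.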
# Preload enhancement model at startup
print("🔄 Loading prompt enhancement model...")
rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
rewriter_model = AutoModelForCausalLM.from_pretrained(
    REWRITER_MODEL,
    torch_dtype=dtype,
    device_map="auto",
    quantization_config=bnb_config,
)
print("✅ Enhancement model loaded and ready!")
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.

## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure newly added elements or modifications align with the image's original context and art style.

## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.

### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserve the original structure and language—**do not translate** or alter style.

### Human Editing (e.g., change a person's face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.

### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**, if applicable:
  `"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`

## 3. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
    "Rewritten": "..."
}
'''
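# The system prompt requests a bare {"Rewritten": "..."} object; extract_json_response()
# below defends against the deviations small chat models actually produce
# (code fences, unquoted keys, trailing commas, renamed or misspelled keys).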
def extract_json_response(model_output: str) -> str:
    """Extract the rewritten instruction from potentially messy JSON output"""
    # Remove code block markers first
    model_output = re.sub(r'```(?:json)?\s*', '', model_output)
    try:
        # Find the JSON portion in the output
        start_idx = model_output.find('{')
        end_idx = model_output.rfind('}')
        # Check that both braces were found and are ordered correctly
        if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
            print(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
            return None
        # Expand to the full object including the outer braces
        end_idx += 1  # Include the closing brace
        json_str = model_output[start_idx:end_idx]
        # Handle potential markdown or other formatting
        json_str = json_str.strip()
        # Try to parse JSON directly first
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Direct JSON parsing failed: {e}")
            # If direct parsing fails, try cleanup:
            # quote bare keys properly
            json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
            # remove any trailing commas that might cause issues
            json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
            # Try parsing again
            data = json.loads(json_str)
        # Extract the rewritten prompt from possible key variations
        # (including misspellings the model has been seen to produce)
        possible_keys = [
            "Rewritten", "rewritten", "Rewrited", "rewrited", "Rewrittent",
            "Output", "output", "Enhanced", "enhanced"
        ]
        for key in possible_keys:
            if key in data:
                return data[key].strip()
        # Try a nested path, e.g. {"Response": {"Rewritten": ...}}
| if "Response" in data and "Rewritten" in data["Response"]: | |
| return data["Response"]["Rewritten"].strip() | |
        # Handle nested JSON objects (additional protection)
        if isinstance(data, dict):
            for value in data.values():
                if isinstance(value, dict) and "Rewritten" in value:
                    return value["Rewritten"].strip()
            # Try to find any string value that looks like an instruction
            str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
            if str_values:
                return str_values[0].strip()
    except Exception as e:
        print(f"JSON parse error: {str(e)}")
        print(f"Model output was: {model_output}")
    return None
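# Illustrative behaviour on hypothetical model outputs (not real transcripts):
#   extract_json_response('```json\n{"Rewritten": "Replace the dog with a cat"}\n```')
#       -> 'Replace the dog with a cat'   (code fences stripped, then parsed directly)
#   extract_json_response('{Rewritten: "Add a hat", }')
#       -> 'Add a hat'                    (bare key quoted, trailing comma removed)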
def polish_prompt(original_prompt: str) -> str:
    """Enhanced prompt rewriting using the edit system prompt, with JSON handling"""
    # Format as a Qwen chat conversation
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_EDIT},
        {"role": "user", "content": original_prompt}
    ]
    text = rewriter_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = rewriter_model.generate(
            **model_inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=rewriter_tokenizer.eos_token_id
        )
    # Extract and clean the response (decode only the newly generated tokens)
    enhanced = rewriter_tokenizer.decode(
        generated_ids[0][model_inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    ).strip()
    print(f"Model raw output: {enhanced}")  # Debug logging
    # Try to extract JSON content
    rewritten_prompt = extract_json_response(enhanced)
    if rewritten_prompt:
        # Clean up remaining artifacts
        rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
        rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
        return rewritten_prompt
    else:
        # Fallback: try to extract from code blocks or just return cleaned content
        if '```' in enhanced:
            parts = enhanced.split('```')
            if len(parts) >= 2:
                rewritten_prompt = parts[1].strip()
            else:
                rewritten_prompt = enhanced
        else:
            rewritten_prompt = enhanced
        # Basic cleanup
        rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
        if ': ' in rewritten_prompt:
            rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
        return rewritten_prompt[:200] if rewritten_prompt else original_prompt
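# Illustrative (hypothetical) call: polish_prompt("make sky blue") might return
# something like 'Replace the overcast sky with a clear blue sky'; outputs vary
# run to run because sampling is enabled (do_sample=True, temperature=0.7).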
# Scheduler configuration for Lightning
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
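# Note: with base_shift == max_shift == log(3) and exponential time shifting, the
# dynamic shift collapses to a constant exp(log 3) = 3 at every resolution, which
# appears to be the intent of the Lightning scheduler recipe.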
# Initialize scheduler with Lightning config
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

# Load main image editing pipeline
pipe = QwenImageEditPipeline.from_pretrained(
    LOC,
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)

# Load LoRA weights for acceleration
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
)
pipe.fuse_lora()

if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
else:
    print("xformers not available")
# def unload_rewriter():
#     """Clear enhancement model from memory"""
#     global rewriter_tokenizer, rewriter_model
#     if rewriter_model:
#         del rewriter_tokenizer, rewriter_model
#         rewriter_tokenizer = None
#         rewriter_model = None
#         torch.cuda.empty_cache()
#         gc.collect()
# ZeroGPU entry point: @spaces.GPU requests a GPU slot for the duration of the call
# (the spaces import is otherwise unused, and this Space runs on Zero)
@spaces.GPU
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=8,
    rewrite_prompt=True,
    num_images_per_prompt=1,
    progress=gr.Progress(track_tqdm=True),
):
| """Image editing endpoint with optimized prompt handling""" | |
| # Resize image to max 1024px on longest side | |
| def resize_image(pil_image, max_size=1024): | |
| """Resize image to maximum dimension of 1024px while maintaining aspect ratio""" | |
| try: | |
| if pil_image is None: | |
| return pil_image | |
| width, height = pil_image.size | |
| max_dimension = max(width, height) | |
| if max_dimension <= max_size: | |
| return pil_image # No resize needed | |
| # Calculate new dimensions maintaining aspect ratio | |
| scale = max_size / max_dimension | |
| new_width = int(width * scale) | |
| new_height = int(height * scale) | |
| # Resize image | |
| resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS) | |
| print(f"📝 Image resized from {width}x{height} to {new_width}x{new_height}") | |
| return resized_image | |
| except Exception as e: | |
| print(f"⚠️ Image resize failed: {e}") | |
| return pil_image # Return original if resize fails | |
| # Add noise function for batch variation | |
| def add_noise_to_image(pil_image, noise_level=0.05): | |
| """Add slight noise to image to create variation in outputs""" | |
| try: | |
| if pil_image is None: | |
| return pil_image | |
| img_array = np.array(pil_image).astype(np.float32) / 255.0 | |
| noise = np.random.normal(0, noise_level, img_array.shape) | |
| noisy_array = img_array + noise | |
| # Clip values to valid range | |
| noisy_array = np.clip(noisy_array, 0, 1) | |
| # Convert back to PIL | |
| noisy_array = (noisy_array * 255).astype(np.uint8) | |
| return Image.fromarray(noisy_array) | |
| except Exception as e: | |
| print(f"Warning: Could not add noise to image: {e}") | |
| return pil_image # Return original if noise addition fails | |
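    # Batch variation strategy: each batch item gets its own seed, a slightly different
    # conditioning-image noise level, and a jittered guidance scale, so the outputs
    # differ instead of being near-duplicates of one another.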
    # Resize input image first
    image = resize_image(image, max_size=1024)
    original_prompt = prompt
    prompt_info = ""
    # Handle prompt rewriting
    if rewrite_prompt:
        try:
            enhanced_instruction = polish_prompt(original_prompt)
            if enhanced_instruction and enhanced_instruction != original_prompt:
                prompt_info = (
                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
                    f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
                    f"<p><strong>Original:</strong> {original_prompt}</p>"
                    f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
                    f"</div>"
                )
                prompt = enhanced_instruction
            else:
                prompt_info = (
                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800; background: #fff8f0'>"
                    f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
                    f"<p>No enhancement applied or enhancement failed</p>"
                    f"</div>"
                )
        except Exception as e:
            print(f"Prompt enhancement error: {str(e)}")  # Debug logging
            gr.Warning(f"Prompt enhancement failed: {str(e)}")
            prompt_info = (
                f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
                f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
                f"<p>Using original prompt. Error: {str(e)[:100]}</p>"
                f"</div>"
            )
    else:
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
            f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
            f"<p>{original_prompt}</p>"
            f"</div>"
        )
    # Set base seed for reproducibility
    base_seed = seed if not randomize_seed else random.randint(0, MAX_SEED)
    try:
        # Generate images with variation for batch mode
        if num_images_per_prompt > 1:
            edited_images = []
            for i in range(num_images_per_prompt):
                # Create a unique seed for each image
                generator = torch.Generator(device=device).manual_seed(base_seed + i * 1000)
                # Add slight noise to the image for variation
                noisy_image = add_noise_to_image(image, noise_level=0.05 + i * 0.003)
                # Slightly vary the guidance scale
                varied_guidance = true_guidance_scale + random.uniform(-0.5, 0.5)
                varied_guidance = max(1.0, min(10.0, varied_guidance))
                # Generate a single image with variations
                result = pipe(
                    image=noisy_image,
                    prompt=prompt,
                    negative_prompt=" ",
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    true_cfg_scale=varied_guidance,
                    num_images_per_prompt=1
                ).images
                edited_images.extend(result)
        else:
            # Single image generation (unchanged input image)
            generator = torch.Generator(device=device).manual_seed(base_seed)
            edited_images = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=" ",
                num_inference_steps=num_inference_steps,
                generator=generator,
                true_cfg_scale=true_guidance_scale,
                num_images_per_prompt=num_images_per_prompt
            ).images
        # Clear cache after generation
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        return edited_images, base_seed, prompt_info
    except Exception as e:
        # Clear cache on error
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        # gr.Error must be raised to display; use a warning toast (as above) so we
        # can still return fallback outputs to the UI
        gr.Warning(f"Image generation failed: {str(e)}")
        return [], base_seed, (
            f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
            f"<h4 style='margin-top: 0;'>⚠️ Processing Error</h4>"
            f"<p>{str(e)[:200]}</p>"
            f"</div>"
        )
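# Illustrative direct call, bypassing the UI (assumes a hypothetical local "input.jpg"):
#   images, used_seed, info_html = infer(Image.open("input.jpg"), "Replace the sky with a sunset")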
with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
    gr.Markdown("""
    <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">
        <h1 style="margin-bottom: 5px;">⚡️ Qwen-Image-Edit Lightning</h1>
        <p>✨ 8-step inference with lightx2v's LoRA.</p>
        <p>📝 Local Prompt Enhancement, Batched Multi-image Generation</p>
    </div>
    """)
    with gr.Row(equal_height=True):
        # Input Column
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Source Image",
                type="pil",
                height=300
            )
            prompt = gr.Textbox(
                label="Edit Instructions",
                placeholder="e.g. Replace the background with a beach sunset...",
                lines=2,
                max_lines=4
            )
            with gr.Row():
                rewrite_toggle = gr.Checkbox(
                    label="Enable Prompt Enhancement",
                    value=True,
                    interactive=True
                )
                run_button = gr.Button(
                    "Generate Edits",
                    variant="primary",
                    min_width=120
                )
            with gr.Accordion("Advanced Parameters", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42
                    )
                    randomize_seed = gr.Checkbox(
                        label="Random Seed",
                        value=True
                    )
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=4.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=4,
                        maximum=16,
                        step=1,
                        value=8
                    )
                num_images_per_prompt = gr.Slider(
                    label="Output Count",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2
                )
        # Output Column
        with gr.Column(scale=2):
            result = gr.Gallery(
                label="Edited Images",
                columns=2,  # gr.Gallery expects an int (or per-breakpoint list), not a callable
                height=500,
                object_fit="cover",
                preview=True
            )
            prompt_info = gr.HTML(
                value="<div style='padding:15px; margin-top:15px'>"
                      "Prompt details will appear after generation</div>"
            )

    # # Examples
    # gr.Examples(
    #     examples=[
    #         "Change the background scene to a rooftop bar at night",
    #         "Transform to pixel art style with 8-bit graphics",
    #         "Replace all text with 'Qwen AI' in futuristic font"
    #     ],
    #     inputs=[prompt],
    #     label="Sample Instructions",
    #     cache_examples=True
    # )

    # Set up processing
    inputs = [
        input_image,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        rewrite_toggle,
        num_images_per_prompt
    ]
    outputs = [result, seed, prompt_info]
    run_button.click(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )
    prompt.submit(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )

demo.queue(max_size=5).launch()