Spaces:
Runtime error
Runtime error
| import spaces | |
| import os | |
| import json | |
| import time | |
| import torch | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import gradio as gr | |
| from safetensors.torch import save_file | |
| from src.pipeline import FluxPipeline | |
| from src.transformer_flux import FluxTransformer2DModel | |
| from src.lora_helper import set_single_lora, set_multi_lora, unset_lora | |
# Model / asset locations.
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# System prompt that is hidden from users but automatically prepended to their input.
SYSTEM_PROMPT = "Ghibli Studio style, Charming hand-drawn anime-style illustration"

# Load the project's own FluxTransformer2DModel first and hand it to the
# pipeline as a component override. Assigning `pipe.transformer` after
# `from_pretrained` made the pipeline load the stock FLUX transformer weights
# as well, only to discard them, wasting startup time and peak host memory.
# NOTE(review): relies on src.pipeline.FluxPipeline inheriting diffusers'
# DiffusionPipeline.from_pretrained component-override kwargs — confirm.
transformer = FluxTransformer2DModel.from_pretrained(
    base_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(
    base_path, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` list held by every attention processor of *transformer*.

    The generation code calls this after each image so that state accumulated
    in ``bank_kv`` (presumably cached key/value tensors for the condition
    branch — confirm against src.lora_helper) does not carry over to the next
    request.
    """
    processors = transformer.attn_processors
    for processor in processors.values():
        processor.bank_kv.clear()
# Generation entry point wired to both the button and the examples gallery.
# `spaces` was imported but never used. On ZeroGPU Spaces this decorator is
# what attaches a GPU for the duration of the call; outside ZeroGPU it has
# no effect.
@spaces.GPU()
def single_condition_generate_image(user_prompt, spatial_img, height, width, seed):
    """Generate one Ghibli-style image with the FLUX pipeline and the Ghibli LoRA.

    Args:
        user_prompt: Free-text description from the user. May be empty or None.
            When non-blank it is appended to the hidden ``SYSTEM_PROMPT``.
        spatial_img: Optional reference image used as the spatial condition.
            The UI supplies it via ``gr.Image(type="pil")``. ``None`` means
            no condition.
        height: Output height in pixels, coerced to ``int``.
        width: Output width in pixels, coerced to ``int``.
        seed: RNG seed, coerced to ``int``. ``torch.Generator.manual_seed``
            rejects floats, and slider values may arrive as floats.

    Returns:
        The first image of the pipeline output.
    """
    # Whitespace-only input is treated like no input, so the prompt never
    # ends with a dangling ", ".
    user_prompt = (user_prompt or "").strip()
    full_prompt = f"{SYSTEM_PROMPT}, {user_prompt}" if user_prompt else SYSTEM_PROMPT

    # (Re)install the Ghibli LoRA on the shared transformer for this request.
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

    spatial_imgs = [spatial_img] if spatial_img is not None else []
    try:
        image = pipe(
            full_prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
    finally:
        # Reset the processors' KV bank even when generation fails, so a
        # failed request cannot leave stale state behind for the next one.
        clear_cache(pipe.transformer)
    return image
# Example rows shown in the "Inspiration Gallery".
def load_examples(test_img_dir="./test_imgs"):
    """Build the rows for ``gr.Examples`` from the bundled reference images.

    Each row is ``[prompt, image, height, width, seed]``, matching the inputs
    of the generation function. Files missing from *test_img_dir* are skipped.
    The seed is the example's 1-based position in the table, so it stays the
    same even when an earlier file is missing.

    Args:
        test_img_dir: Directory holding the example PNGs. Defaults to the
            previously hard-coded ``"./test_imgs"``.

    Returns:
        A list of example rows, possibly empty.
    """
    # (filename, user prompt without the system prompt, height, width);
    # output sizes are fixed per example image.
    specs = [
        ("00.png", "a cat sitting by the window", 680, 1024),
        ("02.png", "a peaceful mountain village", 560, 1024),
        ("03.png", "a young girl with flowers in her hair", 568, 1024),
        ("04.png", "a magical forest with spirits", 768, 672),
        ("06.png", "a flying castle in the clouds", 896, 1024),
        ("07.png", "a serene river with boats", 528, 800),
        ("08.png", "a cozy cottage in the countryside", 696, 1024),
        ("09.png", "a bustling market in a small town", 896, 1024),
    ]
    examples = []
    for seed, (filename, prompt, height, width) in enumerate(specs, start=1):
        img_path = os.path.join(test_img_dir, filename)
        if not os.path.exists(img_path):
            continue
        # Image.open is lazy and keeps the file open. Decode now so the
        # context manager can close the handle while the pixel data stays
        # usable.
        with Image.open(img_path) as img:
            img.load()
        examples.append([prompt, img, height, width, seed])
    return examples
# Custom stylesheet passed to `gr.Blocks(css=...)`. It contains:
#   - theme variables
#   - the header banner
#   - card/panel and button styling
#   - gallery and footer layout
#   - a small-screen override
#   - three palette accent classes
css = """
:root {
    --primary-color: #4a6670;
    --accent-color: #ff8a65;
    --background-color: #f5f5f5;
    --card-background: #ffffff;
    --text-color: #333333;
    --border-radius: 10px;
    --shadow: 0 4px 6px rgba(0,0,0,0.1);
}
body {
    background-color: var(--background-color);
    color: var(--text-color);
    font-family: 'Helvetica Neue', Arial, sans-serif;
}
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
.gr-header {
    background: linear-gradient(135deg, #668796 0%, #4a6670 100%);
    padding: 24px;
    border-radius: var(--border-radius);
    margin-bottom: 24px;
    box-shadow: var(--shadow);
    text-align: center;
}
.gr-header h1 {
    color: white;
    font-size: 2.5rem;
    margin: 0;
    font-weight: 700;
}
.gr-header p {
    color: rgba(255, 255, 255, 0.9);
    font-size: 1.1rem;
    margin-top: 8px;
}
.gr-panel {
    background-color: var(--card-background);
    border-radius: var(--border-radius);
    padding: 16px;
    box-shadow: var(--shadow);
}
.gr-button {
    background-color: var(--accent-color);
    border: none;
    color: white;
    padding: 10px 20px;
    border-radius: 5px;
    font-size: 16px;
    font-weight: bold;
    cursor: pointer;
    transition: transform 0.1s, background-color 0.3s;
}
.gr-button:hover {
    background-color: #ff7043;
    transform: translateY(-2px);
}
.gr-input, .gr-select {
    border-radius: 5px;
    border: 1px solid #ddd;
    padding: 10px;
    width: 100%;
}
.gr-form {
    display: grid;
    gap: 16px;
}
.gr-box {
    background-color: var(--card-background);
    border-radius: var(--border-radius);
    padding: 20px;
    box-shadow: var(--shadow);
    margin-bottom: 20px;
}
.gr-gallery {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
    gap: 16px;
}
.gr-gallery-item {
    overflow: hidden;
    border-radius: var(--border-radius);
    box-shadow: var(--shadow);
    transition: transform 0.3s;
}
.gr-gallery-item:hover {
    transform: scale(1.02);
}
.gr-image {
    width: 100%;
    height: auto;
    object-fit: cover;
}
.gr-footer {
    text-align: center;
    margin-top: 40px;
    padding: 20px;
    color: #666;
    font-size: 14px;
}
.gr-examples-gallery {
    margin-top: 20px;
}
/* Responsive adjustments */
@media (max-width: 768px) {
    .gr-header h1 {
        font-size: 1.8rem;
    }
    .gr-panel {
        padding: 12px;
    }
}
/* Ghibli-inspired accent colors */
.gr-accent-1 {
    background-color: #95ccd9;
}
.gr-accent-2 {
    background-color: #74ad8c;
}
.gr-accent-3 {
    background-color: #f9c06b;
}
"""
# Assemble the UI. The page has:
#   - a header banner
#   - a "Create Ghibli Art" tab with inputs in the left column, the result in
#     the right column, and an examples gallery below
#   - a footer
# The Generate button and the gallery drive the same function with the same
# input components.
with gr.Blocks(css=css) as demo:
    gr.HTML("""
<div class="gr-header">
    <h1>✨ Ghibli Art Generator ✨</h1>
    <p>Transform your ideas into magical Ghibli-inspired artwork</p>
</div>
""")

    with gr.Tab("Create Ghibli Art"):
        with gr.Row():
            # Left column: everything the user controls.
            with gr.Column(scale=1):
                gr.HTML("""
<div class="gr-box">
    <h3>🎨 Your Creative Input</h3>
    <p>Describe what you want to see in your Ghibli-inspired image</p>
</div>
""")
                user_prompt = gr.Textbox(
                    label="Your description",
                    placeholder="Describe what you want to see (e.g., a cat sitting by the window)",
                    lines=2,
                )
                spatial_img = gr.Image(
                    label="Reference Image (Optional)",
                    type="pil",
                    elem_classes="gr-image-upload",
                )
                with gr.Group():
                    with gr.Row():
                        height = gr.Slider(
                            minimum=256, maximum=1024, step=64, label="Height", value=768,
                        )
                        width = gr.Slider(
                            minimum=256, maximum=1024, step=64, label="Width", value=768,
                        )
                    seed = gr.Slider(
                        minimum=1,
                        maximum=9999,
                        step=1,
                        label="Seed",
                        value=42,
                        info="Change for different variations",
                    )
                generate_btn = gr.Button("✨ Generate Ghibli Art", elem_classes="gr-button")

            # Right column: the generated result.
            with gr.Column(scale=1):
                gr.HTML("""
<div class="gr-box">
    <h3>✨ Your Magical Creation</h3>
    <p>Your Ghibli-inspired artwork will appear here</p>
</div>
""")
                output_image = gr.Image(label="Generated Image", elem_classes="gr-output-image")

        # Shared input list, in the generation function's parameter order.
        generation_inputs = [user_prompt, spatial_img, height, width, seed]

        gr.HTML("""
<div class="gr-box gr-examples-gallery">
    <h3>✨ Inspiration Gallery</h3>
    <p>Click on any example to try it out</p>
</div>
""")
        examples = load_examples()
        gr.Examples(
            examples=examples,
            inputs=generation_inputs,
            outputs=output_image,
            fn=single_condition_generate_image,
            cache_examples=False,
            examples_per_page=4,
        )

    gr.HTML("""
<div class="gr-footer">
    <p>Powered by FLUX.1 and Ghibli LoRA • Created with ❤️</p>
</div>
""")

    # Wire the Generate button to the generation function.
    generate_btn.click(
        fn=single_condition_generate_image,
        inputs=generation_inputs,
        outputs=output_image,
    )
# Start the Gradio server with request queueing. This only runs when the file
# is executed as a script (as Spaces does with `python app.py`), so importing
# the module does not start a server.
if __name__ == "__main__":
    demo.queue().launch()