# DiT360 / app.py
# Author: Insta360-Research
# Last commit: 0f82d97 ("Update app.py", verified)
import gradio as gr
import torch
import numpy as np
import random
from PIL import Image
import spaces
from src.pipeline import DiT360Pipeline
# Runtime configuration: half precision on a CUDA device when one is
# available, otherwise full precision on the CPU.
# NOTE(review): FLUX pipelines are commonly run in bfloat16; confirm that
# float16 is the intended GPU dtype here.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Base FLUX.1-dev weights, with the DiT360 LoRA weights loaded on top.
model_repo = "black-forest-labs/FLUX.1-dev"
lora_weights = "Insta360-Research/DiT360-Panorama-Image-Generation"

pipe = DiT360Pipeline.from_pretrained(model_repo, torch_dtype=torch_dtype)
pipe = pipe.to(device)
pipe.load_lora_weights(lora_weights)

# Largest seed handed out by the random-seed button (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): MAX_IMAGE_SIZE is not referenced anywhere in this file; the
# output size is hard-coded in infer().
MAX_IMAGE_SIZE = 2048
@spaces.GPU
def infer(
    prompt,
    seed,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one 2048x1024 panorama from a text prompt.

    Args:
        prompt: Scene description entered by the user. It is stripped and
            appended to a fixed panorama prefix before being sent to ``pipe``.
        seed: Value from the "Seed" number box. Gradio may deliver it as a
            float, so it is converted to ``int``. An empty box (``None``) is
            reported to the user instead of failing with a TypeError.
        num_inference_steps: Denoising step count from the slider, coerced to
            ``int`` before it is passed to the pipeline.
        progress: Gradio progress tracker. It is not used directly;
            ``track_tqdm=True`` lets Gradio mirror tqdm progress in the UI.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        gr.Error: If no seed value was provided.
    """
    if seed is None:
        raise gr.Error("Please enter a seed value.")
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # The fixed prefix is the trigger text that the UI tip refers to
    # ("DiT360 will automatically add a trigger word").
    # NOTE(review): "The images shows" is kept verbatim because it may match
    # the phrasing used during LoRA training; confirm before fixing grammar.
    full_prompt = f"This is a panorama. The images shows {prompt.strip()}"
    # Output size is fixed at 2:1 (2048x1024), as stated in the Settings panel.
    image = pipe(
        full_prompt,
        width=2048,
        height=1024,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=2.8,
        generator=generator,
    ).images[0]
    return image
def generate_seed():
    """Return a random integer seed drawn uniformly from [0, MAX_SEED]."""
    upper_inclusive = MAX_SEED
    # randrange's stop is exclusive, hence the +1.
    return random.randrange(0, upper_inclusive + 1)
# Prompts listed in the Examples panel; clicking one fills the prompt box.
examples = [
    # Outdoor scenes
    (
        "A medieval castle stands proudly on a hilltop surrounded by "
        "autumn forests, with golden light spilling across the landscape."
    ),
    "A futuristic cityscape under a starry night sky.",
    (
        "A futuristic city skyline reflects on the calm river at sunset, "
        "neon lights glowing against the twilight sky."
    ),
    (
        "A snowy mountain village under northern lights, "
        "with cozy cabins and smoke rising from chimneys."
    ),
    "A grand Gothic cathedral towers over a bustling European plaza.",
    (
        "A modern glass skyscraper district rises above "
        "green parks and wide boulevards."
    ),
    (
        "A desert city with sandstone buildings glows warmly "
        "under the golden rays of the setting sun."
    ),
    (
        "A Greek island village with whitewashed buildings and blue domes "
        "shines under the midday sun."
    ),
    (
        "A futuristic floating city hovers above the sea, "
        "its architecture glowing with holographic lights."
    ),
    "A canyon with carved temples and ancient ruins stretches into the horizon.",
]
# Custom stylesheet passed to gr.Blocks(css=...). Each selector targets a
# component through its elem_id.
# NOTE(review): #main-container, #image-panel and #input-panel do not match
# any elem_id used by the layout in this file (it uses top-row, top-panel,
# bottom-row, example-panel, settings-panel and prompt-box). Confirm whether
# these rules are stale or whether an elem_id was renamed.
css = """
#main-container {
display: flex;
flex-direction: column;
gap: 2rem;
margin-top: 1rem;
}
#top-row {
display: flex;
flex-direction: row;
justify-content: center;
align-items: flex-start;
gap: 2rem;
}
#bottom-row {
display: flex;
flex-direction: row;
gap: 2rem;
}
#image-panel {
flex: 2;
max-width: 1200px;
margin: 0 auto;
}
#input-panel {
flex: 1;
}
#example-panel {
flex: 2;
}
#settings-panel {
flex: 1;
max-width: 280px;
}
#prompt-box textarea {
resize: none !important;
}
"""
with gr.Blocks(css=css) as demo:
    # Header: title and project links.
    # NOTE(review): the Dataset link is identical to the Pretrained Model link;
    # confirm the intended dataset URL.
    gr.Markdown(
        """
# 🌀 DiT360: High-Fidelity Panoramic Image Generation
Here are our resources:
- 💻 **Code**: [https://github.com/Insta360-Research-Team/DiT360](https://github.com/Insta360-Research-Team/DiT360)
- 🌐 **Web Page**: [https://fenghora.github.io/DiT360-Page/](https://fenghora.github.io/DiT360-Page/)
- 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation](https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation)
- 📊 **Dataset**: [https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation](https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation)
"""
    )
    gr.Markdown("Official Gradio demo for **[DiT360](https://fenghora.github.io/DiT360-Page/)**, a panoramic image generation model based on hybrid training.")

    # Top row: generated image, prompt box and the Generate button.
    with gr.Row(elem_id="top-row"):
        with gr.Column(elem_id="top-panel"):
            result = gr.Image(label="Generated Panorama", show_label=False, type="pil", height=800)
            prompt = gr.Textbox(
                elem_id="prompt-box",
                placeholder="Describe your panoramic scene here...",
                show_label=False,
                lines=2,
                container=False,
            )
            run_button = gr.Button("Generate", variant="primary")

    # Bottom row: example prompts on the left, generation settings on the right.
    with gr.Row(elem_id="bottom-row"):
        with gr.Column(elem_id="example-panel"):
            gr.Markdown("### 📚 Examples")
            gr.Examples(examples=examples, inputs=[prompt])
        with gr.Column(elem_id="settings-panel"):
            gr.Markdown("### ⚙️ Settings")
            gr.Markdown(
                "For better results, the image **width and height are fixed** at 2048×1024 (2:1 aspect ratio). "
            )
            seed_display = gr.Number(value=0, label="Seed", interactive=True)
            random_seed_button = gr.Button("🎲 Random Seed")
            random_seed_button.click(fn=generate_seed, inputs=[], outputs=seed_display)
            num_inference_steps = gr.Slider(28, 100, value=50, step=1, label="Inference Steps")
            gr.Markdown(
                "💡 *Tip: Try descriptive prompts like “A mountain village at sunrise with mist over the valley.” "
                "DiT360 will automatically add a trigger word.*"
            )

    # Run generation both from the button and when Enter is pressed in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed_display, num_inference_steps],
        outputs=[result],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()