ImagenWorld-Visualizer / src /streamlit_app.py
samin's picture
Update src/streamlit_app.py
3951a6d verified
raw
history blame
12.8 kB
# app.py
import json
import re
from pathlib import Path
from urllib.parse import urlparse
import streamlit as st
# ──────────────────────────────────────────────────────────────────────────────
# Page config
# ──────────────────────────────────────────────────────────────────────────────
# Configure the browser tab title/icon and the default layout. This is the
# first Streamlit call in the script, ahead of any other page element.
st.set_page_config(
    page_title="ImagenHub2 Data Visualization",
    # The icon was stored as mis-decoded UTF-8 bytes; restore the intended
    # "framed picture" emoji so the tab shows a real icon.
    page_icon="🖼️",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ──────────────────────────────────────────────────────────────────────────────
# Basics
# ──────────────────────────────────────────────────────────────────────────────
# Number of items shown per task initially, and added by each "Show more" click.
DEFAULT_CHUNK = 10

# Task codes, in the order they are offered in the sidebar and shown as tabs.
DATA_DIRS = [
    "TIG",
    "TIE",
    "SRIG",
    "SRIE",
    "MRIG",
    "MRIE",
]

# Human-readable description for each task code (shown in the sidebar).
TASK_DESCRIPTIONS = {
    "TIG": "Text to Image Generation - Generate images from text prompts only",
    "TIE": "Text and Image to Image Editing - Edit images based on text prompts",
    "SRIG": "Single-Reference Image Generation - Generate images using text + single reference image",
    "SRIE": "Single-Reference Image Editing - Edit images using text + single reference image",
    "MRIG": "Multi-Reference Image Generation - Generate images using text + multiple reference images",
    "MRIE": "Multi-Reference Image Editing - Edit images using text + multiple reference images",
}

# The manifest JSON is read from the directory that contains this script.
MANIFEST_PATH = Path(__file__).with_name("manifest_v1.json")
# Seed the per-task display limits only when the key is absent, so values
# stored by the "Show more" / "Reset" buttons are not overwritten on rerun.
if "show_counts" not in st.session_state:
    st.session_state.show_counts = dict.fromkeys(DATA_DIRS, DEFAULT_CHUNK)
# ──────────────────────────────────────────────────────────────────────────────
# Helpers: natural sorting for ids and filenames
# ──────────────────────────────────────────────────────────────────────────────
_num_re = re.compile(r"(\d+)")
def _natural_key(s: str):
parts = _num_re.split(s)
out = []
for p in parts:
if p.isdigit():
out.append(int(p))
else:
out.append(p.lower())
return out
def _basename_from_url(url: str) -> str:
    """Return the final path component of ``url``.

    The query string and fragment are ignored because only the parsed URL
    path is used. If parsing or path handling raises, the input is returned
    unchanged.
    """
    try:
        url_path = urlparse(url).path
        filename = Path(url_path).name
    except Exception:
        # Best effort: an unparsable value is used as its own name.
        return url
    else:
        return filename
def _sorted_urls(urls):
    """Return a new list of ``urls`` ordered naturally by their file names."""

    def _filename_key(u):
        # Sort on the file name only, so host and directory parts are ignored.
        return _natural_key(_basename_from_url(u))

    ordered = list(urls)
    ordered.sort(key=_filename_key)
    return ordered
# ──────────────────────────────────────────────────────────────────────────────
# Load manifest (local file or baked into Space)
# ──────────────────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False)
def load_manifest():
    """Read the manifest file and group its items by task.

    Returns:
        A ``(per_task, topics)`` tuple. ``per_task`` maps each task name to
        its items, naturally sorted by ``item_id``; every task in
        ``DATA_DIRS`` is present even when empty, and items without a
        ``task`` key are filed under ``"Unknown"``. ``topics`` is the sorted
        list of distinct truthy ``topic`` values.
    """
    with open(MANIFEST_PATH, "r", encoding="utf-8") as fh:
        manifest = json.load(fh)

    grouped = {task: [] for task in DATA_DIRS}
    seen_topics = set()

    for entry in manifest.get("items", []):
        # Order the image URL lists inside each item for determinism.
        for field in ("cond_image_urls", "model_output_urls"):
            if entry.get(field):
                entry[field] = _sorted_urls(entry[field])

        task_name = entry.get("task", "Unknown")
        grouped.setdefault(task_name, []).append(entry)

        topic = entry.get("topic")
        if topic:
            seen_topics.add(topic)

    # Natural order by item_id within each task.
    for entries in grouped.values():
        entries.sort(key=lambda e: _natural_key(str(e.get("item_id", ""))))

    return grouped, sorted(seen_topics)
# ──────────────────────────────────────────────────────────────────────────────
# Stable image grid
# ──────────────────────────────────────────────────────────────────────────────
def _display_images(urls, caption_prefix="", max_per_row=3, fixed_height_px=None):
    """Render ``urls`` as a grid of captioned images.

    Args:
        urls: Image URLs to show; when empty a "No images found." note is
            written instead.
        caption_prefix: Text placed before each image's file name in the
            caption.
        max_per_row: Number of columns in the grid; images wrap round-robin.
        fixed_height_px: When truthy, each image is drawn inside an HTML
            frame of this height (reserving space before the image loads).
            Otherwise ``st.image`` is used.
    """
    # Local import keeps this block self-contained; html is standard library.
    import html

    if not urls:
        st.write("No images found.")
        return

    cols = st.columns(max_per_row, vertical_alignment="top")
    for i, url in enumerate(urls):
        name = _basename_from_url(url)
        caption = f"{caption_prefix} {name}"
        with cols[i % max_per_row]:
            if fixed_height_px:
                # The markup is rendered with unsafe_allow_html, so every
                # interpolated value is escaped: '&', quotes or '<' in a URL
                # or file name must not break the attributes or inject tags.
                safe_url = html.escape(str(url), quote=True)
                safe_alt = html.escape(str(name), quote=True)
                safe_caption = html.escape(caption)
                # Reserve space and avoid reflow while the image loads.
                st.markdown(
                    f"""
<div class="img-frame" style="height:{fixed_height_px}px; display:flex; align-items:center; justify-content:center; overflow:hidden; border-radius:12px;">
<img src="{safe_url}" alt="{safe_alt}" style="max-height:100%; width:100%; object-fit:contain;" />
</div>
<div class="img-cap" style="font-size:0.85rem; opacity:0.8; margin-top:4px;">
{safe_caption}
</div>
""",
                    unsafe_allow_html=True,
                )
            else:
                st.image(url, caption=caption, use_container_width=True)
# ──────────────────────────────────────────────────────────────────────────────
# Global CSS to reduce “vibrating” / layout reflow
# ──────────────────────────────────────────────────────────────────────────────
def _inject_css(_fixed_height_px: int | None):
    """Inject the page-wide stylesheet via ``st.markdown``.

    The rules tighten the top padding, make Streamlit-rendered images fill
    their container, enable font smoothing, and give ``.img-frame`` a faint
    background. ``_fixed_height_px`` is accepted but not read by this
    function.
    """
    st.markdown(
        """
<style>
.block-container { padding-top: 0.75rem; }
[data-testid="stImage"] img {
width: 100%;
height: auto;
object-fit: contain;
display: block;
}
html * {
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
}
.img-frame { background: rgba(0,0,0,0.03); }
</style>
""",
        unsafe_allow_html=True,
    )
# ──────────────────────────────────────────────────────────────────────────────
# App
# ──────────────────────────────────────────────────────────────────────────────
def _item_matches(item, needle, topics, subtopic):
    """Return True when ``item`` passes every active sidebar filter.

    Args:
        item: One manifest item (a dict).
        needle: Lower-cased, stripped search text; empty disables the search.
        topics: Selected topics; empty disables the topic filter.
        subtopic: Stripped subtopic text; empty disables the subtopic filter.
    """
    if needle:
        # ``or ""`` guards against JSON nulls, which .get() returns as None
        # and which would otherwise crash on .lower().
        prompt = str(item.get("prompt") or "").lower()
        refined = str(item.get("prompt_refined") or "").lower()
        if needle not in prompt and needle not in refined:
            return False
    if topics and item.get("topic", "") not in topics:
        return False
    if subtopic and item.get("subtopic", "") != subtopic:
        return False
    return True


def _render_item(item, image_height):
    """Render one manifest item as a collapsed expander with its images.

    ``image_height`` is forwarded to ``_display_images`` as
    ``fixed_height_px`` (``None`` selects the plain ``st.image`` layout).
    """
    header = (
        f"**{item.get('item_id', '?')}** — "
        f"{item.get('topic', 'Unknown')} / {item.get('subtopic', 'Unknown')}"
    )
    with st.expander(header, expanded=False):
        c1, c2, c3 = st.columns([1, 1, 1], vertical_alignment="top")
        with c1:
            st.write(f"**Task:** {item.get('task', 'Unknown')}")
        with c2:
            st.write(f"**Topic:** {item.get('topic', 'Unknown')}")
        with c3:
            st.write(f"**Subtopic:** {item.get('subtopic', 'Unknown')}")

        st.write("**Original Prompt:**")
        st.write(item.get("prompt", "—"))
        if item.get("prompt_refined"):
            st.write("**Refined Prompt:**")
            st.write(item.get("prompt_refined", ""))
        if item.get("remarks"):
            st.write("**Remarks:**")
            st.write(item.get("remarks", ""))

        cond_urls = item.get("cond_image_urls", [])
        model_urls = item.get("model_output_urls", [])
        if cond_urls:
            st.write("**Condition Images:**")
            _display_images(
                cond_urls, "Condition", max_per_row=3,
                fixed_height_px=image_height,
            )
        if model_urls:
            st.write("**Model Output:**")
            _display_images(
                model_urls, "Model", max_per_row=3,
                fixed_height_px=image_height,
            )


def main():
    """Build the page: sidebar filters, one tab per task, paged item lists.

    Each tab shows the first ``limit`` items that pass the sidebar filters,
    where ``limit`` is stored per task in ``st.session_state.show_counts``
    and changed by the "Show more" / "Reset" buttons.
    """
    # User-facing strings below had mis-decoded UTF-8 (emoji, em dashes,
    # curly quotes); they are restored to the intended characters.
    st.title("🖼️ ImagenHub2 Data Visualization")
    st.markdown(
        f"Each task starts with **{DEFAULT_CHUNK}** items — "
        f"click **Show more** to load **+{DEFAULT_CHUNK}**."
    )

    # Load manifest first (to get topic list)
    with st.spinner("Loading manifest…"):
        per_task, all_topics = load_manifest()

    # Sidebar
    st.sidebar.header("Filters")
    fixed_height_on = st.sidebar.toggle(
        "Stabilize grid with fixed image height",
        value=True,
        help="Pre-allocate space for images to prevent page ‘vibrating’.",
    )
    fixed_height_px = st.sidebar.number_input(
        "Fixed image height (px)",
        min_value=120, max_value=1200, value=320, step=20,
        disabled=not fixed_height_on,
    )
    # Single source of truth for "fixed height or natural size".
    image_height = fixed_height_px if fixed_height_on else None
    _inject_css(image_height)

    selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
    search_query = st.sidebar.text_input("🔍 Search in prompts", "")
    # Topic filter behaves like task filter (multiselect)
    topic_filter = st.sidebar.multiselect(
        "Select Topics",
        all_topics,
        default=[],
        help="Filter items by one or more topic IDs.",
    )
    subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")

    st.sidebar.header("Task Descriptions")
    for t in selected_tasks:
        st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")

    # Normalize the free-text filters once instead of once per item; stripping
    # also keeps a stray space from silently hiding every item.
    needle = search_query.strip().lower()
    subtopic = subtopic_filter.strip()

    # Tabs per selected task (st.tabs is skipped when nothing is selected)
    tabs = st.tabs(selected_tasks) if selected_tasks else []
    for task, tab in zip(selected_tasks, tabs):
        with tab:
            st.subheader(task)
            limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)
            all_items = per_task.get(task, [])

            # Apply filters
            filtered = [
                it for it in all_items
                if _item_matches(it, needle, topic_filter, subtopic)
            ]
            batch = filtered[:limit]
            st.caption(
                f"Showing {len(batch)} / {len(filtered)} "
                f"(from {len(all_items)} items in {task})"
            )

            if not batch:
                st.warning("No items match current filters.")
            else:
                for it in batch:
                    _render_item(it, image_height)

            st.divider()

            # Pagination controls
            c_more, c_reset, c_info = st.columns(3)
            with c_more:
                # Disabled once everything is visible so the stored limit
                # cannot grow past the number of matching items.
                if st.button(
                    "Show more",
                    key=f"more_{task}",
                    disabled=limit >= len(filtered),
                ):
                    st.session_state.show_counts[task] = limit + DEFAULT_CHUNK
                    st.rerun()
            with c_reset:
                if st.button("Reset", key=f"reset_{task}"):
                    st.session_state.show_counts[task] = DEFAULT_CHUNK
                    st.rerun()
            with c_info:
                st.caption(f"Current limit: {limit}")
# Build the page only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()