# app.py
"""Streamlit viewer for the ImagenHub2 manifest.

Reads ``manifest_v1.json`` from the directory containing this file, groups
its items by task, and renders one tab per selected task with prompt, topic
and subtopic filters plus "Show more" pagination.
"""
import html
import json
import re
from pathlib import Path
from urllib.parse import urlparse

import streamlit as st

# ──────────────────────────────────────────────────────────────────────────────
# Page config
# ──────────────────────────────────────────────────────────────────────────────
st.set_page_config(
    page_title="ImagenHub2 Data Visualization",
    page_icon="🖼️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ──────────────────────────────────────────────────────────────────────────────
# Basics
# ──────────────────────────────────────────────────────────────────────────────
MANIFEST_PATH = Path(__file__).parent / "manifest_v1.json"
DATA_DIRS = ["TIG", "TIE", "SRIG", "SRIE", "MRIG", "MRIE"]
TASK_DESCRIPTIONS = {
    "TIG": "Text to Image Generation - Generate images from text prompts only",
    "TIE": "Text and Image to Image Editing - Edit images based on text prompts",
    "SRIG": "Single-Reference Image Generation - Generate images using text + single reference image",
    "SRIE": "Single-Reference Image Editing - Edit images using text + single reference image",
    "MRIG": "Multi-Reference Image Generation - Generate images using text + multiple reference images",
    "MRIE": "Multi-Reference Image Editing - Edit images using text + multiple reference images",
}

# Number of items shown initially per task, and added by each "Show more".
DEFAULT_CHUNK = 10

# Per-task display limit, kept in session state so it survives reruns.
if "show_counts" not in st.session_state:
    st.session_state.show_counts = {task: DEFAULT_CHUNK for task in DATA_DIRS}

# ──────────────────────────────────────────────────────────────────────────────
# Helpers: natural sorting for ids and filenames
# ──────────────────────────────────────────────────────────────────────────────
_num_re = re.compile(r"(\d+)")


def _natural_key(s: str):
    """Return a sort key that orders embedded numbers numerically.

    The string is split on digit runs; digit runs become ints and the text
    between them is lower-cased, so "img2" sorts before "img10".
    """
    parts = _num_re.split(s)
    # re.split with a capturing group alternates text, digits, text, ...
    # so odd positions are exactly the digit runs.  Testing the position
    # instead of str.isdigit() avoids a ValueError for characters such as
    # "²" that isdigit() accepts but \d does not match and int() rejects.
    return [int(p) if i % 2 else p.lower() for i, p in enumerate(parts)]


def _basename_from_url(url: str) -> str:
    """Return the last path component of *url*, or *url* itself as fallback.

    The fallback is used when parsing fails or when the URL has no file
    name in its path, so captions and sort keys are never empty for a
    non-empty URL.
    """
    try:
        name = Path(urlparse(url).path).name
    except Exception:
        # Deliberately best-effort: a malformed entry must not break the page.
        return str(url)
    return name or str(url)


def _sorted_urls(urls):
    """Return *urls* as a new list, naturally ordered by file name."""
    return sorted(urls, key=lambda u: _natural_key(_basename_from_url(u)))


# ──────────────────────────────────────────────────────────────────────────────
# Load manifest (local file or baked into Space)
# ──────────────────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False)
def load_manifest():
    """Load the manifest and group its items by task.

    Returns:
        A ``(per_task, topics)`` tuple. ``per_task`` maps each task name to
        its items, naturally sorted by ``item_id``; every name in
        ``DATA_DIRS`` is present even when it has no items. ``topics`` is
        the sorted list of distinct non-empty ``topic`` values.

    Raises:
        OSError: if the manifest file cannot be opened.
        json.JSONDecodeError: if the manifest is not valid JSON.
    """
    with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
        man = json.load(f)

    items = man.get("items", [])
    per_task = {t: [] for t in DATA_DIRS}
    topics = set()

    for it in items:
        # Sort the image url lists inside each item for determinism.
        if it.get("cond_image_urls"):
            it["cond_image_urls"] = _sorted_urls(it["cond_image_urls"])
        if it.get("model_output_urls"):
            it["model_output_urls"] = _sorted_urls(it["model_output_urls"])

        per_task.setdefault(it.get("task", "Unknown"), []).append(it)
        if it.get("topic"):
            topics.add(it["topic"])

    # Sort items within each task by item_id (natural order).
    for lst in per_task.values():
        lst.sort(key=lambda it: _natural_key(str(it.get("item_id", ""))))

    return per_task, sorted(topics)


# ──────────────────────────────────────────────────────────────────────────────
# Stable image grid
# ──────────────────────────────────────────────────────────────────────────────
def _display_images(urls, caption_prefix="", max_per_row=3, fixed_height_px=None):
    """Render *urls* as a grid of captioned images.

    Args:
        urls: Image URLs to show; an empty value prints "No images found.".
        caption_prefix: Text placed before each file name in the caption.
        max_per_row: Number of grid columns.
        fixed_height_px: When truthy, each image is drawn inside a container
            of this height so the layout does not shift while images load.
            When falsy, ``st.image`` is used with its natural height.
    """
    if not urls:
        st.write("No images found.")
        return

    cols = st.columns(max_per_row, vertical_alignment="top")
    for i, url in enumerate(urls):
        name = _basename_from_url(url)
        caption = f"{caption_prefix} {name}"
        with cols[i % max_per_row]:
            if fixed_height_px:
                # Reserve space and avoid reflow while the image loads.
                # Everything interpolated comes from the manifest, so it is
                # escaped before being injected as raw HTML.
                # NOTE(review): this path assumes the entries are URLs a
                # browser can fetch directly — confirm for local file paths.
                safe_url = html.escape(str(url), quote=True)
                safe_name = html.escape(name, quote=True)
                safe_caption = html.escape(caption)
                st.markdown(
                    f'<div style="height:{int(fixed_height_px)}px;display:flex;'
                    f'align-items:center;justify-content:center;overflow:hidden;">'
                    f'<img src="{safe_url}" alt="{safe_name}" loading="lazy" '
                    f'style="max-width:100%;max-height:100%;object-fit:contain;"/>'
                    f'</div>'
                    f'<div style="text-align:center;font-size:0.85rem;'
                    f'opacity:0.7;">{safe_caption}</div>',
                    unsafe_allow_html=True,
                )
            else:
                st.image(url, caption=caption, use_container_width=True)


# ──────────────────────────────────────────────────────────────────────────────
# Global CSS to reduce “vibrating” / layout reflow
# ──────────────────────────────────────────────────────────────────────────────
def _inject_css(_fixed_height_px: int | None):
    """Inject the page-wide stylesheet, if there is one.

    ``_fixed_height_px`` is accepted for interface compatibility and is not
    used here.
    """
    # NOTE(review): the stylesheet body is blank in this file, so nothing is
    # injected; put the <style> block here if global rules are needed.
    css = """ """
    if css.strip():
        st.markdown(css, unsafe_allow_html=True)


# ──────────────────────────────────────────────────────────────────────────────
# App
# ──────────────────────────────────────────────────────────────────────────────
def _matches_filters(it, search, topics, subtopic):
    """Return True when item *it* passes every active sidebar filter.

    Args:
        it: Manifest item (a dict).
        search: Lower-cased, stripped prompt query; empty disables it.
        topics: Selected topics; empty disables the topic filter.
        subtopic: Stripped subtopic to match exactly; empty disables it.
    """
    if search:
        # `or ""` guards against explicit nulls in the manifest.
        prompt = str(it.get("prompt") or "").lower()
        refined = str(it.get("prompt_refined") or "").lower()
        if search not in prompt and search not in refined:
            return False
    if topics and it.get("topic", "") not in topics:
        return False
    if subtopic and str(it.get("subtopic") or "").strip() != subtopic:
        return False
    return True


def _render_item(it, fixed_height_px):
    """Render one manifest item inside a collapsed expander."""
    header = (
        f"**{it.get('item_id','?')}** — "
        f"{it.get('topic','Unknown')} / {it.get('subtopic','Unknown')}"
    )
    with st.expander(header, expanded=False):
        c1, c2, c3 = st.columns([1, 1, 1], vertical_alignment="top")
        with c1:
            st.write(f"**Task:** {it.get('task','Unknown')}")
        with c2:
            st.write(f"**Topic:** {it.get('topic','Unknown')}")
        with c3:
            st.write(f"**Subtopic:** {it.get('subtopic','Unknown')}")

        st.write("**Original Prompt:**")
        st.write(it.get("prompt") or "—")

        if it.get("prompt_refined"):
            st.write("**Refined Prompt:**")
            st.write(it.get("prompt_refined", ""))

        if it.get("remarks"):
            st.write("**Remarks:**")
            st.write(it.get("remarks", ""))

        cond_urls = it.get("cond_image_urls", [])
        model_urls = it.get("model_output_urls", [])

        if cond_urls:
            st.write("**Condition Images:**")
            _display_images(
                cond_urls, "Condition", max_per_row=3, fixed_height_px=fixed_height_px
            )

        if model_urls:
            st.write("**Model Output:**")
            _display_images(
                model_urls, "Model", max_per_row=3, fixed_height_px=fixed_height_px
            )


def _render_task_tab(task, all_items, search, topics, subtopic, fixed_height_px):
    """Render the filtered, paginated item list and controls for one task."""
    st.subheader(task)
    limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)

    filtered = [
        it for it in all_items if _matches_filters(it, search, topics, subtopic)
    ]
    batch = filtered[:limit]

    st.caption(
        f"Showing {len(batch)} / {len(filtered)} (from {len(all_items)} items in {task})"
    )

    if not batch:
        st.warning("No items match current filters.")
    else:
        for it in batch:
            _render_item(it, fixed_height_px)

    st.divider()

    # Pagination controls
    c_more, c_reset, c_info = st.columns(3)
    with c_more:
        # Nothing more to reveal once the limit covers every filtered item.
        if st.button(
            "Show more", key=f"more_{task}", disabled=limit >= len(filtered)
        ):
            st.session_state.show_counts[task] = limit + DEFAULT_CHUNK
            st.rerun()
    with c_reset:
        if st.button("Reset", key=f"reset_{task}"):
            st.session_state.show_counts[task] = DEFAULT_CHUNK
            st.rerun()
    with c_info:
        st.caption(f"Current limit: {limit}")


def main():
    """Build the page: title, sidebar filters, and one tab per selected task."""
    st.title("🖼️ ImagenHub2 Data Visualization")
    st.markdown(
        f"Each task starts with **{DEFAULT_CHUNK}** items — "
        f"click **Show more** to load **+{DEFAULT_CHUNK}**."
    )

    # Load manifest first (to get topic list)
    with st.spinner("Loading manifest…"):
        try:
            per_task, _all_topics = load_manifest()
        except (OSError, json.JSONDecodeError) as exc:
            # Show a readable message instead of a raw traceback.
            st.error(f"Could not load manifest at {MANIFEST_PATH}: {exc}")
            st.stop()

    # Sidebar
    st.sidebar.header("Filters")

    fixed_height_on = st.sidebar.toggle(
        "Stabilize grid with fixed image height",
        value=True,
        help="Pre-allocate space for images to prevent page ‘vibrating’.",
    )
    fixed_height_px = st.sidebar.number_input(
        "Fixed image height (px)",
        min_value=120,
        max_value=1200,
        value=320,
        step=20,
        disabled=not fixed_height_on,
    )
    # None means "let st.image pick the height".
    height = fixed_height_px if fixed_height_on else None
    _inject_css(height)

    selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
    search_query = st.sidebar.text_input("🔍 Search in prompts", "")

    # Topic filter behaves like task filter (multiselect)
    topic_filter = st.sidebar.multiselect(
        "Select Topics",
        _all_topics,
        default=[],
        help="Filter items by one or more topic IDs.",
    )

    subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")

    st.sidebar.header("Task Descriptions")
    for t in selected_tasks:
        st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")

    # Normalise the free-text filters once rather than once per item.
    search = search_query.strip().lower()
    subtopic = subtopic_filter.strip()

    # Tabs per selected task (st.tabs is skipped when no task is selected).
    tabs = st.tabs(selected_tasks) if selected_tasks else []
    for task, tab in zip(selected_tasks, tabs):
        with tab:
            _render_task_tab(
                task, per_task.get(task, []), search, topic_filter, subtopic, height
            )


if __name__ == "__main__":
    main()