# Source: ImagenWorld-Visualizer / src/streamlit_app.py
# Last change: "Update src/streamlit_app.py" by samin (commit f6a2362, verified), 8.37 kB
# app.py
import json
from pathlib import Path
import streamlit as st
# ──────────────────────────────────────────────────────────────────────────────
# Page config
# ──────────────────────────────────────────────────────────────────────────────
st.set_page_config(
    page_title="ImagenHub2 Data Visualization",
    page_icon="🖼️",  # framed-picture emoji shown in the browser tab
    layout="wide",
    initial_sidebar_state="expanded",
)
# ──────────────────────────────────────────────────────────────────────────────
# Basics
# ──────────────────────────────────────────────────────────────────────────────
# Manifest listing every item; it lives in the same directory as this script.
MANIFEST_PATH = Path(__file__).parent / "manifest_v1.json"

# Task codes, in the order offered by the sidebar task selector.
DATA_DIRS = ["TIG", "TIE", "SRIG", "SRIE", "MRIG", "MRIE"]

# Sidebar description for each task code.
TASK_DESCRIPTIONS = dict(
    TIG="Text to Image Generation - Generate images from text prompts only",
    TIE="Text and Image to Image Editing - Edit images based on text prompts",
    SRIG="Single-Reference Image Generation - Generate images using text + single reference image",
    SRIE="Single-Reference Image Editing - Edit images using text + single reference image",
    MRIG="Multi-Reference Image Generation - Generate images using text + multiple reference images",
    MRIE="Multi-Reference Image Editing - Edit images using text + multiple reference images",
)

# Page size: both the initial number of items per task and the "Show more" step.
DEFAULT_CHUNK = 10

# Per-task item counts persist across Streamlit reruns; seed them only once.
if "show_counts" not in st.session_state:
    st.session_state.show_counts = dict.fromkeys(DATA_DIRS, DEFAULT_CHUNK)
# ──────────────────────────────────────────────────────────────────────────────
# Load manifest (GitHub RAW URLs)
# ──────────────────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False)
def load_manifest():
    """Read the manifest at ``MANIFEST_PATH`` and group its items by task.

    Returns:
        A ``(per_task, topics)`` tuple.

        ``per_task`` maps task codes to lists of item dicts, in manifest order.
        Every code in ``DATA_DIRS`` is always present, possibly with an empty
        list. Any other ``"task"`` value found in the manifest gets its own
        key. Items with no ``"task"`` key are grouped under ``"Unknown"``.

        ``topics`` is the sorted list of distinct, non-empty ``"topic"``
        values across all items.

    Raises:
        OSError: if the manifest file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the top-level JSON object has no ``"items"`` key.

    The result is cached with ``st.cache_data``. The function takes no
    arguments, so later edits to the manifest file are not seen until that
    cache is cleared.
    """
    with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
        manifest = json.load(f)

    per_task = {task: [] for task in DATA_DIRS}
    topics = set()
    for item in manifest["items"]:
        per_task.setdefault(item.get("task", "Unknown"), []).append(item)
        if item.get("topic"):
            topics.add(item["topic"])
    # sorted() accepts any iterable; no intermediate list needed.
    return per_task, sorted(topics)
# ──────────────────────────────────────────────────────────────────────────────
# Image grid
# ──────────────────────────────────────────────────────────────────────────────
def _display_images(urls, caption_prefix="", max_per_row=3):
    """Lay out ``urls`` as an image grid at most ``max_per_row`` columns wide.

    Image ``i`` goes into column ``i % max_per_row``, so extra images stack
    under earlier ones in the same column. Each caption is ``caption_prefix``,
    a space, then the last path component of the URL. If ``urls`` is empty
    or falsy, a "No images found." note is written instead.
    """
    if urls:
        n_cols = min(len(urls), max_per_row)
        grid = st.columns(n_cols)
        for idx, src in enumerate(urls):
            column = grid[idx % max_per_row]
            with column:
                label = f"{caption_prefix} {Path(src).name}"
                st.image(src, caption=label, use_container_width=True)
    else:
        st.write("No images found.")
# ──────────────────────────────────────────────────────────────────────────────
# App
# ──────────────────────────────────────────────────────────────────────────────
def _item_matches(it, search_query, topic_filter, subtopic_filter):
    """Return True if manifest item ``it`` passes every active sidebar filter.

    - ``search_query``: case-insensitive substring of the original or the
      refined prompt. Ignored when empty.
    - ``topic_filter``: the item's topic must be one of the selected topics.
      Ignored when nothing is selected.
    - ``subtopic_filter``: the item's subtopic must equal this string exactly.
      Ignored when empty.
    """
    if search_query:
        needle = search_query.lower()
        # `or ""` also covers keys that are present but null in the manifest.
        prompts = (
            (it.get("prompt") or "").lower(),
            (it.get("prompt_refined") or "").lower(),
        )
        if not any(needle in p for p in prompts):
            return False
    if topic_filter and it.get("topic", "") not in topic_filter:
        return False
    if subtopic_filter and it.get("subtopic", "") != subtopic_filter:
        return False
    return True


def _render_item(it):
    """Render one manifest item as a collapsed expander with its prompts and images."""
    header = (
        f"**{it.get('item_id', '?')}** — "
        f"{it.get('topic', 'Unknown')} / {it.get('subtopic', 'Unknown')}"
    )
    with st.expander(header, expanded=False):
        c1, c2, c3 = st.columns(3)
        with c1:
            st.write(f"**Task:** {it.get('task', 'Unknown')}")
        with c2:
            st.write(f"**Topic:** {it.get('topic', 'Unknown')}")
        with c3:
            st.write(f"**Subtopic:** {it.get('subtopic', 'Unknown')}")

        st.write("**Original Prompt:**")
        st.write(it.get("prompt", "—"))
        if it.get("prompt_refined"):
            st.write("**Refined Prompt:**")
            st.write(it["prompt_refined"])
        if it.get("remarks"):
            st.write("**Remarks:**")
            st.write(it["remarks"])

        cond_urls = it.get("cond_image_urls", [])
        model_urls = it.get("model_output_urls", [])
        if cond_urls:
            st.write("**Condition Images:**")
            _display_images(cond_urls, "Condition")
        if model_urls:
            st.write("**Model Output:**")
            _display_images(model_urls, "Model")
    # NOTE(review): the flattened original does not show whether this divider
    # was inside the expander; here it separates consecutive items.
    st.divider()


def _render_pagination(task, limit, total):
    """Render the "Show more" / "Reset" controls for one task tab.

    Either button updates ``st.session_state.show_counts[task]`` and reruns
    the script so the new limit takes effect immediately.
    """
    c_more, c_reset, c_info = st.columns(3)
    with c_more:
        # Nothing left to reveal once the limit covers every filtered item.
        if st.button("Show more", key=f"more_{task}", disabled=limit >= total):
            st.session_state.show_counts[task] = limit + DEFAULT_CHUNK
            st.rerun()
    with c_reset:
        if st.button("Reset", key=f"reset_{task}"):
            st.session_state.show_counts[task] = DEFAULT_CHUNK
            st.rerun()
    with c_info:
        st.caption(f"Current limit: {limit}")


def _render_task_tab(task, items, search_query, topic_filter, subtopic_filter):
    """Render the filtered, paginated item list for ``task`` inside its tab."""
    st.subheader(task)
    limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)
    filtered = [
        it for it in items
        if _item_matches(it, search_query, topic_filter, subtopic_filter)
    ]
    batch = filtered[:limit]
    st.caption(f"Showing {len(batch)} / {len(filtered)} (from {len(items)} items in {task})")

    if batch:
        for it in batch:
            _render_item(it)
    else:
        st.warning("No items match current filters.")

    _render_pagination(task, limit, len(filtered))


def main():
    """Build the page: sidebar filters plus one paginated tab per selected task."""
    st.title("🖼️ ImagenHub2 Data Visualization")
    st.markdown(
        f"Each task starts with **{DEFAULT_CHUNK}** items — click **Show more** "
        f"to load **+{DEFAULT_CHUNK}**."
    )
    # The topic selector is populated from the manifest, so load it first.
    with st.spinner("Loading manifest…"):
        per_task, all_topics = load_manifest()

    # Sidebar; free-text inputs are stripped so stray whitespace doesn't filter.
    st.sidebar.header("Filters")
    selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
    search_query = st.sidebar.text_input("🔍 Search in prompts", "").strip()
    topic_filter = st.sidebar.multiselect("Select Topics", all_topics, default=[])
    subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "").strip()

    st.sidebar.header("Task Descriptions")
    for t in selected_tasks:
        st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")

    # st.tabs needs at least one label, so stop here when nothing is selected.
    if not selected_tasks:
        st.info("Select at least one task in the sidebar.")
        return
    for task, tab in zip(selected_tasks, st.tabs(selected_tasks)):
        with tab:
            _render_task_tab(
                task,
                per_task.get(task, []),
                search_query,
                topic_filter,
                subtopic_filter,
            )


if __name__ == "__main__":
    main()