|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
import streamlit as st |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-level settings for the app: browser tab title and icon, wide layout,
# and a sidebar that starts expanded.
st.set_page_config(
    page_title="ImagenHub2 Data Visualization",
    # The icon must be a real emoji; the previous value was a mis-decoded
    # (mojibake) byte sequence rather than the framed-picture emoji.
    page_icon="🖼️",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The manifest JSON lives next to this script.
MANIFEST_PATH = Path(__file__).with_name("manifest_v1.json")

# Human-readable description of every task, keyed by task code.  The key order
# here is also the order in which tasks are offered in the UI.
TASK_DESCRIPTIONS = {
    "TIG": "Text to Image Generation - Generate images from text prompts only",
    "TIE": "Text and Image to Image Editing - Edit images based on text prompts",
    "SRIG": "Single-Reference Image Generation - Generate images using text + single reference image",
    "SRIE": "Single-Reference Image Editing - Edit images using text + single reference image",
    "MRIG": "Multi-Reference Image Generation - Generate images using text + multiple reference images",
    "MRIE": "Multi-Reference Image Editing - Edit images using text + multiple reference images",
}

# Task codes, in display order (same order as the descriptions above).
DATA_DIRS = list(TASK_DESCRIPTIONS)

# Number of items shown initially per task, and added by each "Show more".
DEFAULT_CHUNK = 10

# Per-task display limits persist across reruns in the session state; they are
# created once, on the first run of a session.
if "show_counts" not in st.session_state:
    st.session_state.show_counts = dict.fromkeys(DATA_DIRS, DEFAULT_CHUNK)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_data(show_spinner=False)
def load_manifest():
    """Load the manifest file and group its items by task.

    Reads the JSON document at ``MANIFEST_PATH`` and takes its top-level
    ``"items"`` entry as an iterable of dict-like items.

    Returns:
        tuple: ``(per_task, topics)`` where ``per_task`` maps a task code to
        the list of its items (every code in ``DATA_DIRS`` is present, even
        when empty; items without a usable task are grouped under
        ``"Unknown"``) and ``topics`` is the sorted list of distinct
        non-empty ``"topic"`` values.

    Raises:
        OSError: If the manifest file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the manifest has no ``"items"`` entry.
    """
    with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
        manifest = json.load(f)

    per_task = {task: [] for task in DATA_DIRS}
    topics = set()
    for item in manifest["items"]:
        # ``or`` also covers an explicit JSON null / empty task, which the
        # ``.get`` default alone would let through as a ``None`` key.
        per_task.setdefault(item.get("task") or "Unknown", []).append(item)
        topic = item.get("topic")
        if topic:
            topics.add(topic)

    return per_task, sorted(topics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _display_images(urls, caption_prefix="", max_per_row=3):
    """Show the images in *urls* in a grid of at most *max_per_row* columns.

    Each image is captioned with *caption_prefix*, a space, and the final
    path component of its URL.  Images beyond the first row wrap around into
    the same columns.  When *urls* is empty a short notice is written instead.
    """
    if urls:
        columns = st.columns(min(len(urls), max_per_row))
        width = len(columns)
        for position, image_url in enumerate(urls):
            file_name = Path(image_url).name
            # Cycle through the columns so extra images stack below the first row.
            with columns[position % width]:
                st.image(
                    image_url,
                    caption=f"{caption_prefix} {file_name}",
                    use_container_width=True,
                )
    else:
        st.write("No images found.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _item_matches(item, search_query, topic_filter, subtopic_filter):
    """Return True when *item* passes every active sidebar filter.

    An empty filter value disables that filter.  The search query is matched
    case-insensitively as a substring of the original or refined prompt; the
    topic must be one of the selected topics; the subtopic must equal the
    entered text exactly.
    """
    if search_query:
        needle = search_query.lower()
        # ``or ""`` keeps items whose prompt is an explicit null from crashing.
        prompts = (
            str(item.get("prompt") or ""),
            str(item.get("prompt_refined") or ""),
        )
        if not any(needle in text.lower() for text in prompts):
            return False
    if topic_filter and item.get("topic", "") not in topic_filter:
        return False
    if subtopic_filter and item.get("subtopic", "") != subtopic_filter:
        return False
    return True


def _render_item(item):
    """Render one manifest item (metadata, prompts, images) in an expander."""
    header = (
        f"**{item.get('item_id', '?')}** — "
        f"{item.get('topic', 'Unknown')} / {item.get('subtopic', 'Unknown')}"
    )
    with st.expander(header, expanded=False):
        task_col, topic_col, subtopic_col = st.columns(3)
        with task_col:
            st.write(f"**Task:** {item.get('task', 'Unknown')}")
        with topic_col:
            st.write(f"**Topic:** {item.get('topic', 'Unknown')}")
        with subtopic_col:
            st.write(f"**Subtopic:** {item.get('subtopic', 'Unknown')}")

        st.write("**Original Prompt:**")
        st.write(item.get("prompt", "—"))
        if item.get("prompt_refined"):
            st.write("**Refined Prompt:**")
            st.write(item.get("prompt_refined", ""))
        if item.get("remarks"):
            st.write("**Remarks:**")
            st.write(item.get("remarks", ""))

        cond_urls = item.get("cond_image_urls", [])
        model_urls = item.get("model_output_urls", [])
        if cond_urls:
            st.write("**Condition Images:**")
            _display_images(cond_urls, "Condition")
        if model_urls:
            st.write("**Model Output:**")
            _display_images(model_urls, "Model")


def _render_task_tab(task, all_items, search_query, topic_filter, subtopic_filter):
    """Render the filtered, paginated item list and paging controls for *task*."""
    st.subheader(task)
    limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)

    filtered = [
        item
        for item in all_items
        if _item_matches(item, search_query, topic_filter, subtopic_filter)
    ]
    batch = filtered[:limit]

    st.caption(f"Showing {len(batch)} / {len(filtered)} (from {len(all_items)} items in {task})")

    if not batch:
        st.warning("No items match current filters.")
    else:
        for item in batch:
            _render_item(item)
            st.divider()

    more_col, reset_col, info_col = st.columns(3)
    with more_col:
        # Disabled once everything is visible so the stored limit cannot grow
        # past the number of matching items.
        if st.button(
            "Show more",
            key=f"more_{task}",
            disabled=limit >= len(filtered),
        ):
            st.session_state.show_counts[task] = limit + DEFAULT_CHUNK
            st.rerun()
    with reset_col:
        if st.button("Reset", key=f"reset_{task}"):
            st.session_state.show_counts[task] = DEFAULT_CHUNK
            st.rerun()
    with info_col:
        st.caption(f"Current limit: {limit}")


def main():
    """Build the page: sidebar filters plus one paginated tab per selected task."""
    st.title("🖼️ ImagenHub2 Data Visualization")
    st.markdown(
        f"Each task starts with **{DEFAULT_CHUNK}** items — "
        f"click **Show more** to load **+{DEFAULT_CHUNK}**."
    )

    try:
        with st.spinner("Loading manifest…"):
            per_task, all_topics = load_manifest()
    except (OSError, ValueError, KeyError) as exc:
        # Missing file, invalid JSON (a ValueError subclass) or a manifest
        # without "items": show a readable message instead of a traceback.
        st.error(f"Could not load manifest at {MANIFEST_PATH}: {exc}")
        st.stop()

    st.sidebar.header("Filters")
    selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
    search_query = st.sidebar.text_input("🔍 Search in prompts", "")
    topic_filter = st.sidebar.multiselect("Select Topics", all_topics, default=[])
    subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")

    st.sidebar.header("Task Descriptions")
    for task in selected_tasks:
        st.sidebar.write(f"**{task}**: {TASK_DESCRIPTIONS.get(task, '')}")

    if not selected_tasks:
        # Tabs are only created for a non-empty selection.
        st.info("Select at least one task in the sidebar.")
        return

    for task, tab in zip(selected_tasks, st.tabs(selected_tasks)):
        with tab:
            _render_task_tab(
                task,
                per_task.get(task, []),
                search_query,
                topic_filter,
                subtopic_filter,
            )
|
|
|
|
|
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|