samin commited on
Commit
f6a2362
Β·
verified Β·
1 Parent(s): 9b57b6b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +152 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,154 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import json
3
+ from pathlib import Path
4
  import streamlit as st
5
 
6
+ # ──────────────────────────────────────────────────────────────────────────────
7
+ # Page config
8
+ # ──────────────────────────────────────────────────────────────────────────────
9
+ st.set_page_config(
10
+ page_title="ImagenHub2 Data Visualization",
11
+ page_icon="πŸ–ΌοΈ",
12
+ layout="wide",
13
+ initial_sidebar_state="expanded"
14
+ )
15
+
16
+ # ──────────────────────────────────────────────────────────────────────────────
17
+ # Basics
18
+ # ──────────────────────────────────────────────────────────────────────────────
19
+ MANIFEST_PATH = Path(__file__).parent / "manifest_v1.json"
20
+
21
+ DATA_DIRS = ["TIG", "TIE", "SRIG", "SRIE", "MRIG", "MRIE"]
22
+ TASK_DESCRIPTIONS = {
23
+ "TIG": "Text to Image Generation - Generate images from text prompts only",
24
+ "TIE": "Text and Image to Image Editing - Edit images based on text prompts",
25
+ "SRIG": "Single-Reference Image Generation - Generate images using text + single reference image",
26
+ "SRIE": "Single-Reference Image Editing - Edit images using text + single reference image",
27
+ "MRIG": "Multi-Reference Image Generation - Generate images using text + multiple reference images",
28
+ "MRIE": "Multi-Reference Image Editing - Edit images using text + multiple reference images"
29
+ }
30
+
31
+ DEFAULT_CHUNK = 10
32
+ if "show_counts" not in st.session_state:
33
+ st.session_state.show_counts = {task: DEFAULT_CHUNK for task in DATA_DIRS}
34
+
35
+ # ──────────────────────────────────────────────────────────────────────────────
36
+ # Load manifest (GitHub RAW URLs)
37
+ # ──────────────────────────────────────────────────────────────────────────────
38
+ @st.cache_data(show_spinner=False)
39
+ def load_manifest():
40
+ with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
41
+ man = json.load(f)
42
+ items = man["items"]
43
+ per_task = {t: [] for t in DATA_DIRS}
44
+ topics = set()
45
+ for it in items:
46
+ per_task.setdefault(it.get("task", "Unknown"), []).append(it)
47
+ if it.get("topic"):
48
+ topics.add(it["topic"])
49
+ return per_task, sorted(list(topics))
50
+
51
+ # ──────────────────────────────────────────────────────────────────────────────
52
+ # Image grid
53
+ # ──────────────────────────────────────────────────────────────────────────────
54
+ def _display_images(urls, caption_prefix="", max_per_row=3):
55
+ if not urls:
56
+ st.write("No images found.")
57
+ return
58
+ cols = st.columns(min(len(urls), max_per_row))
59
+ for i, url in enumerate(urls):
60
+ with cols[i % max_per_row]:
61
+ st.image(url, caption=f"{caption_prefix} {Path(url).name}", use_container_width=True)
62
+
63
+ # ──────────────────────────────────────────────────────────────────────────────
64
+ # App
65
+ # ──────────────────────────────────────────────────────────────────────────────
66
+ def main():
67
+ st.title("πŸ–ΌοΈ ImagenHub2 Data Visualization")
68
+ st.markdown("Each task starts with **10** items β€” click **Show more** to load **+10**.")
69
+
70
+ # Load manifest first to extract available topics
71
+ with st.spinner("Loading manifest…"):
72
+ per_task, all_topics = load_manifest()
73
+
74
+ # Sidebar
75
+ st.sidebar.header("Filters")
76
+ selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
77
+ search_query = st.sidebar.text_input("πŸ” Search in prompts", "")
78
+ topic_filter = st.sidebar.multiselect("Select Topics", all_topics, default=[])
79
+ subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")
80
+
81
+ st.sidebar.header("Task Descriptions")
82
+ for t in selected_tasks:
83
+ st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")
84
+
85
+ # Tabs per selected task
86
+ tabs = st.tabs(selected_tasks) if selected_tasks else []
87
+ for task, tab in zip(selected_tasks, tabs):
88
+ with tab:
89
+ st.subheader(task)
90
+ limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)
91
+ all_items = per_task.get(task, [])
92
+
93
+ # Apply filters
94
+ def _match(it):
95
+ sq = search_query.lower()
96
+ if search_query and (sq not in it.get("prompt", "").lower()
97
+ and sq not in it.get("prompt_refined", "").lower()):
98
+ return False
99
+ if topic_filter and it.get("topic", "") not in topic_filter:
100
+ return False
101
+ if subtopic_filter and it.get("subtopic", "") != subtopic_filter:
102
+ return False
103
+ return True
104
+
105
+ filtered = [it for it in all_items if _match(it)]
106
+ batch = filtered[:limit]
107
+
108
+ st.caption(f"Showing {len(batch)} / {len(filtered)} (from {len(all_items)} items in {task})")
109
+
110
+ if not batch:
111
+ st.warning("No items match current filters.")
112
+ else:
113
+ for it in batch:
114
+ header = f"**{it.get('item_id','?')}** β€” {it.get('topic','Unknown')} / {it.get('subtopic','Unknown')}"
115
+ with st.expander(header, expanded=False):
116
+ c1, c2, c3 = st.columns(3)
117
+ with c1: st.write(f"**Task:** {it.get('task','Unknown')}")
118
+ with c2: st.write(f"**Topic:** {it.get('topic','Unknown')}")
119
+ with c3: st.write(f"**Subtopic:** {it.get('subtopic','Unknown')}")
120
+ st.write("**Original Prompt:**")
121
+ st.write(it.get("prompt", "β€”"))
122
+ if it.get("prompt_refined"):
123
+ st.write("**Refined Prompt:**")
124
+ st.write(it.get("prompt_refined", ""))
125
+ if it.get("remarks"):
126
+ st.write("**Remarks:**")
127
+ st.write(it.get("remarks", ""))
128
+
129
+ cond_urls = it.get("cond_image_urls", [])
130
+ model_urls = it.get("model_output_urls", [])
131
+
132
+ if cond_urls:
133
+ st.write("**Condition Images:**")
134
+ _display_images(cond_urls, "Condition")
135
+ if model_urls:
136
+ st.write("**Model Output:**")
137
+ _display_images(model_urls, "Model")
138
+ st.divider()
139
+
140
+ # Pagination controls
141
+ c_more, c_reset, c_info = st.columns(3)
142
+ with c_more:
143
+ if st.button("Show more", key=f"more_{task}"):
144
+ st.session_state.show_counts[task] = limit + DEFAULT_CHUNK
145
+ st.rerun()
146
+ with c_reset:
147
+ if st.button("Reset", key=f"reset_{task}"):
148
+ st.session_state.show_counts[task] = DEFAULT_CHUNK
149
+ st.rerun()
150
+ with c_info:
151
+ st.caption(f"Current limit: {limit}")
152
+
153
+ if __name__ == "__main__":
154
+ main()