samin commited on
Commit
c389fac
Β·
verified Β·
1 Parent(s): a4b9fca

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +111 -20
src/streamlit_app.py CHANGED
@@ -1,6 +1,8 @@
1
  # app.py
2
  import json
 
3
  from pathlib import Path
 
4
  import streamlit as st
5
 
6
  # ──────────────────────────────────────────────────────────────────────────────
@@ -33,32 +35,117 @@ if "show_counts" not in st.session_state:
33
  st.session_state.show_counts = {task: DEFAULT_CHUNK for task in DATA_DIRS}
34
 
35
  # ──────────────────────────────────────────────────────────────────────────────
36
- # Load manifest (GitHub RAW URLs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # ──────────────────────────────────────────────────────────────────────────────
38
  @st.cache_data(show_spinner=False)
39
  def load_manifest():
40
  with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
41
  man = json.load(f)
42
- items = man["items"]
43
  per_task = {t: [] for t in DATA_DIRS}
44
  topics = set()
 
45
  for it in items:
 
 
 
 
 
 
46
  per_task.setdefault(it.get("task", "Unknown"), []).append(it)
47
  if it.get("topic"):
48
  topics.add(it["topic"])
 
 
 
 
 
49
  return per_task, sorted(list(topics))
50
 
51
  # ──────────────────────────────────────────────────────────────────────────────
52
- # Image grid
53
  # ──────────────────────────────────────────────────────────────────────────────
54
- def _display_images(urls, caption_prefix="", max_per_row=3):
55
  if not urls:
56
  st.write("No images found.")
57
  return
58
- cols = st.columns(min(len(urls), max_per_row))
 
 
59
  for i, url in enumerate(urls):
60
- with cols[i % max_per_row]:
61
- st.image(url, caption=f"{caption_prefix} {Path(url).name}", use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # ──────────────────────────────────────────────────────────────────────────────
64
  # App
@@ -67,21 +154,25 @@ def main():
67
  st.title("πŸ–ΌοΈ ImagenHub2 Data Visualization")
68
  st.markdown("Each task starts with **10** items β€” click **Show more** to load **+10**.")
69
 
70
- # Load manifest first to extract available topics
71
- with st.spinner("Loading manifest…"):
72
- per_task, all_topics = load_manifest()
73
-
74
- # Sidebar
75
  st.sidebar.header("Filters")
 
 
 
 
76
  selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
77
  search_query = st.sidebar.text_input("πŸ” Search in prompts", "")
78
- topic_filter = st.sidebar.multiselect("Select Topics", all_topics, default=[])
79
  subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")
80
 
81
  st.sidebar.header("Task Descriptions")
82
  for t in selected_tasks:
83
  st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")
84
 
 
 
 
 
85
  # Tabs per selected task
86
  tabs = st.tabs(selected_tasks) if selected_tasks else []
87
  for task, tab in zip(selected_tasks, tabs):
@@ -92,11 +183,11 @@ def main():
92
 
93
  # Apply filters
94
  def _match(it):
95
- sq = search_query.lower()
96
- if search_query and (sq not in it.get("prompt", "").lower()
97
- and sq not in it.get("prompt_refined", "").lower()):
98
  return False
99
- if topic_filter and it.get("topic", "") not in topic_filter:
100
  return False
101
  if subtopic_filter and it.get("subtopic", "") != subtopic_filter:
102
  return False
@@ -113,7 +204,7 @@ def main():
113
  for it in batch:
114
  header = f"**{it.get('item_id','?')}** β€” {it.get('topic','Unknown')} / {it.get('subtopic','Unknown')}"
115
  with st.expander(header, expanded=False):
116
- c1, c2, c3 = st.columns(3)
117
  with c1: st.write(f"**Task:** {it.get('task','Unknown')}")
118
  with c2: st.write(f"**Topic:** {it.get('topic','Unknown')}")
119
  with c3: st.write(f"**Subtopic:** {it.get('subtopic','Unknown')}")
@@ -131,10 +222,10 @@ def main():
131
 
132
  if cond_urls:
133
  st.write("**Condition Images:**")
134
- _display_images(cond_urls, "Condition")
135
  if model_urls:
136
  st.write("**Model Output:**")
137
- _display_images(model_urls, "Model")
138
  st.divider()
139
 
140
  # Pagination controls
 
1
  # app.py
2
  import json
3
+ import re
4
  from pathlib import Path
5
+ from urllib.parse import urlparse
6
  import streamlit as st
7
 
8
  # ──────────────────────────────────────────────────────────────────────────────
 
35
  st.session_state.show_counts = {task: DEFAULT_CHUNK for task in DATA_DIRS}
36
 
37
  # ──────────────────────────────────────────────────────────────────────────────
38
+ # Helpers: natural sorting for ids and filenames
39
+ # ──────────────────────────────────────────────────────────────────────────────
40
+ _num_re = re.compile(r"(\d+)")
41
+
42
+ def _natural_key(s: str):
43
+ parts = _num_re.split(s)
44
+ out = []
45
+ for p in parts:
46
+ if p.isdigit():
47
+ out.append(int(p))
48
+ else:
49
+ out.append(p.lower())
50
+ return out
51
+
52
+ def _basename_from_url(url: str) -> str:
53
+ try:
54
+ return Path(urlparse(url).path).name
55
+ except Exception:
56
+ return url
57
+
58
+ def _sorted_urls(urls):
59
+ return sorted(urls, key=lambda u: _natural_key(_basename_from_url(u)))
60
+
61
+ # ──────────────────────────────────────────────────────────────────────────────
62
+ # Load manifest (local file or baked into Space)
63
  # ──────────────────────────────────────────────────────────────────────────────
64
  @st.cache_data(show_spinner=False)
65
  def load_manifest():
66
  with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
67
  man = json.load(f)
68
+ items = man.get("items", [])
69
  per_task = {t: [] for t in DATA_DIRS}
70
  topics = set()
71
+
72
  for it in items:
73
+ # sort the image url lists inside each item for determinism
74
+ if it.get("cond_image_urls"):
75
+ it["cond_image_urls"] = _sorted_urls(it["cond_image_urls"])
76
+ if it.get("model_output_urls"):
77
+ it["model_output_urls"] = _sorted_urls(it["model_output_urls"])
78
+
79
  per_task.setdefault(it.get("task", "Unknown"), []).append(it)
80
  if it.get("topic"):
81
  topics.add(it["topic"])
82
+
83
+ # sort items within each task by item_id (natural order)
84
+ for t, lst in per_task.items():
85
+ lst.sort(key=lambda it: _natural_key(str(it.get("item_id", ""))))
86
+
87
  return per_task, sorted(list(topics))
88
 
89
  # ──────────────────────────────────────────────────────────────────────────────
90
+ # Stable image grid
91
  # ──────────────────────────────────────────────────────────────────────────────
92
+ def _display_images(urls, caption_prefix="", max_per_row=3, fixed_height_px=None):
93
  if not urls:
94
  st.write("No images found.")
95
  return
96
+
97
+ # 3 equal columns, pin to top so text doesn’t jiggle vertically
98
+ cols = st.columns([1, 1, 1], vertical_alignment="top")
99
  for i, url in enumerate(urls):
100
+ col = cols[i % max_per_row]
101
+ with col:
102
+ if fixed_height_px:
103
+ # Reserve space and avoid reflow while image loads
104
+ st.markdown(
105
+ f"""
106
+ <div class="img-frame" style="height:{fixed_height_px}px; display:flex; align-items:center; justify-content:center; overflow:hidden; border-radius:12px;">
107
+ <img src="{url}" alt="{_basename_from_url(url)}" style="max-height:100%; width:100%; object-fit:contain;" />
108
+ </div>
109
+ <div class="img-cap" style="font-size:0.85rem; opacity:0.8; margin-top:4px;">
110
+ {caption_prefix} {_basename_from_url(url)}
111
+ </div>
112
+ """,
113
+ unsafe_allow_html=True,
114
+ )
115
+ else:
116
+ st.image(url, caption=f"{caption_prefix} {_basename_from_url(url)}", use_container_width=True)
117
+
118
+ # ──────────────────────────────────────────────────────────────────────────────
119
+ # Global CSS to reduce β€œvibrating” / layout reflow
120
+ # ──────────────────────────────────────────────────────────────────────────────
121
+ def _inject_css(fixed_height_px: int | None):
122
+ css = f"""
123
+ <style>
124
+ /* Keep base container tighter so jumps feel smaller */
125
+ .block-container {{ padding-top: 0.75rem; }}
126
+
127
+ /* Images rendered via st.image: constrain to container width and avoid overflow */
128
+ [data-testid="stImage"] img {{
129
+ width: 100%;
130
+ height: auto;
131
+ object-fit: contain;
132
+ display: block;
133
+ }}
134
+
135
+ /* Smooth out font jank on Spaces (fonts can swap) */
136
+ html * {{
137
+ -webkit-font-smoothing: antialiased;
138
+ -moz-osx-font-smoothing: grayscale;
139
+ text-rendering: optimizeLegibility;
140
+ }}
141
+
142
+ /* When using fixed frames (custom HTML), give them a subtle background so size is obvious */
143
+ .img-frame {{
144
+ background: rgba(0,0,0,0.03);
145
+ }}
146
+ </style>
147
+ """
148
+ st.markdown(css, unsafe_allow_html=True)
149
 
150
  # ──────────────────────────────────────────────────────────────────────────────
151
  # App
 
154
  st.title("πŸ–ΌοΈ ImagenHub2 Data Visualization")
155
  st.markdown("Each task starts with **10** items β€” click **Show more** to load **+10**.")
156
 
157
+ # Sidebar (put toggles before we render content)
 
 
 
 
158
  st.sidebar.header("Filters")
159
+ fixed_height_on = st.sidebar.toggle("Stabilize grid with fixed image height", value=True, help="Pre-allocate space for images to prevent page β€˜vibrating’.")
160
+ fixed_height_px = st.sidebar.number_input("Fixed image height (px)", min_value=120, max_value=1200, value=320, step=20, disabled=not fixed_height_on)
161
+ _inject_css(fixed_height_px if fixed_height_on else None)
162
+
163
  selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
164
  search_query = st.sidebar.text_input("πŸ” Search in prompts", "")
165
+ topic_filter = st.sidebar.text_input("Filter by topic id (exact match, optional)", "")
166
  subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")
167
 
168
  st.sidebar.header("Task Descriptions")
169
  for t in selected_tasks:
170
  st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")
171
 
172
+ # Load manifest
173
+ with st.spinner("Loading manifest…"):
174
+ per_task, _all_topics = load_manifest()
175
+
176
  # Tabs per selected task
177
  tabs = st.tabs(selected_tasks) if selected_tasks else []
178
  for task, tab in zip(selected_tasks, tabs):
 
183
 
184
  # Apply filters
185
  def _match(it):
186
+ sq = search_query.strip().lower()
187
+ if sq and (sq not in it.get("prompt", "").lower()
188
+ and sq not in it.get("prompt_refined", "").lower()):
189
  return False
190
+ if topic_filter and it.get("topic", "") != topic_filter:
191
  return False
192
  if subtopic_filter and it.get("subtopic", "") != subtopic_filter:
193
  return False
 
204
  for it in batch:
205
  header = f"**{it.get('item_id','?')}** β€” {it.get('topic','Unknown')} / {it.get('subtopic','Unknown')}"
206
  with st.expander(header, expanded=False):
207
+ c1, c2, c3 = st.columns([1,1,1], vertical_alignment="top")
208
  with c1: st.write(f"**Task:** {it.get('task','Unknown')}")
209
  with c2: st.write(f"**Topic:** {it.get('topic','Unknown')}")
210
  with c3: st.write(f"**Subtopic:** {it.get('subtopic','Unknown')}")
 
222
 
223
  if cond_urls:
224
  st.write("**Condition Images:**")
225
+ _display_images(cond_urls, "Condition", max_per_row=3, fixed_height_px=(fixed_height_px if fixed_height_on else None))
226
  if model_urls:
227
  st.write("**Model Output:**")
228
+ _display_images(model_urls, "Model", max_per_row=3, fixed_height_px=(fixed_height_px if fixed_height_on else None))
229
  st.divider()
230
 
231
  # Pagination controls