bradnow committed on
Commit
d055fb3
·
1 Parent(s): 9a0917a

Update to use a multimodal prompt input, and update API data processing to support images

Browse files
Files changed (2) hide show
  1. app.py +209 -11
  2. styles.css +9 -0
app.py CHANGED
@@ -1,12 +1,17 @@
1
  import random
 
2
  from uuid import uuid4
3
 
4
  from openai import OpenAI
5
  import gradio as gr
 
 
 
 
6
 
7
  from theme import apriel
8
  from utils import COMMUNITY_POSTFIX_URL, get_model_config, check_format, models_config, \
9
- logged_event_handler, DEBUG_MODE, DEBUG_MODEL, log_debug, log_info, log_error
10
  from log_chat import log_chat
11
 
12
  MODEL_TEMPERATURE = 0.8
@@ -119,14 +124,45 @@ def run_chat_inference(history, message, state):
119
  gr.Warning("Client UI is stale, please refresh the page")
120
  return history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
121
 
 
 
 
122
  # outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn, session_state
123
  log_debug(f"{'-' * 80}")
124
  log_debug(f"chat_fn() --> Message: {message}")
125
  log_debug(f"chat_fn() --> History: {history}")
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  try:
128
  # Check if the message is empty
129
- if not message.strip():
130
  gr.Info("Please enter a message before sending")
131
  yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
132
  return history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
@@ -140,7 +176,16 @@ def run_chat_inference(history, message, state):
140
  # Remove any assistant messages with metadata from history for multiple turns
141
  log_debug(f"Initial History: {history}")
142
  check_format(history, "messages")
143
- history.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
144
  log_debug(f"History with user message: {history}")
145
  check_format(history, "messages")
146
 
@@ -155,9 +200,152 @@ def run_chat_inference(history, message, state):
155
  check_format(history_no_thoughts, "messages")
156
  log_debug(f"history_no_thoughts with user message: {history_no_thoughts}")
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  stream = openai_client.chat.completions.create(
159
  model=model_name,
160
- messages=history_no_thoughts,
161
  temperature=MODEL_TEMPERATURE,
162
  stream=True
163
  )
@@ -322,20 +510,30 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
322
  )
323
 
324
  with gr.Row():
325
- with gr.Column(scale=10, min_width=400):
326
  with gr.Row():
327
- user_input = gr.Textbox(
 
 
 
 
328
  show_label=False,
329
- placeholder="Type your message here and press Enter",
330
- container=False
331
  )
 
 
 
 
 
 
 
332
  with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
333
  with gr.Row():
334
  with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
335
- send_btn = gr.Button("Send", variant="primary")
336
- stop_btn = gr.Button("Stop", variant="cancel", visible=False)
337
  with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
338
- clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary")
339
  with gr.Row():
340
  with gr.Column(min_width=400, elem_classes="opt-out-container"):
341
  with gr.Row():
 
1
  import random
2
+ from collections.abc import Mapping
3
  from uuid import uuid4
4
 
5
  from openai import OpenAI
6
  import gradio as gr
7
+ import base64
8
+ import mimetypes
9
+ import copy
10
+ import os
11
 
12
  from theme import apriel
13
  from utils import COMMUNITY_POSTFIX_URL, get_model_config, check_format, models_config, \
14
+ logged_event_handler, DEBUG_MODE, DEBUG_MODEL, log_debug, log_info, log_error, log_warning
15
  from log_chat import log_chat
16
 
17
  MODEL_TEMPERATURE = 0.8
 
124
  gr.Warning("Client UI is stale, please refresh the page")
125
  return history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
126
 
127
+ # files will be the newly added files from the user
128
+ files = []
129
+
130
  # outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn, session_state
131
  log_debug(f"{'-' * 80}")
132
  log_debug(f"chat_fn() --> Message: {message}")
133
  log_debug(f"chat_fn() --> History: {history}")
134
 
135
+ # We have multimodal input in this case
136
+ if isinstance(message, Mapping):
137
+ files = message.get("files") or []
138
+ message = message.get("text") or ""
139
+ log_debug(f"chat_fn() --> Message (text only): {message}")
140
+ log_debug(f"chat_fn() --> Files: {files}")
141
+
142
+ # Validate that any uploaded files are images
143
+ if len(files) > 0:
144
+ invalid_files = []
145
+ for path in files:
146
+ try:
147
+ mime, _ = mimetypes.guess_type(path)
148
+ mime = mime or ""
149
+ if not mime.startswith("image/"):
150
+ invalid_files.append((os.path.basename(path), mime or "unknown"))
151
+ except Exception as e:
152
+ log_error(f"Failed to inspect file '{path}': {e}")
153
+ invalid_files.append((os.path.basename(path), "unknown"))
154
+
155
+ if invalid_files:
156
+ msg = "Only image files are allowed. Invalid uploads: " + \
157
+ ", ".join([f"{p} (type: {m})" for p, m in invalid_files])
158
+ log_warning(msg)
159
+ gr.Warning(msg)
160
+ yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
161
+ return history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
162
+
163
  try:
164
  # Check if the message is empty
165
+ if not message.strip() and len(files) == 0:
166
  gr.Info("Please enter a message before sending")
167
  yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
168
  return history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
 
176
  # Remove any assistant messages with metadata from history for multiple turns
177
  log_debug(f"Initial History: {history}")
178
  check_format(history, "messages")
179
+ # Build UI history: add text (if any) and per-file image placeholders {"path": ...}
180
+ # Build API parts separately later to avoid Gradio issues with arrays in content
181
+ if len(files) == 0:
182
+ history.append({"role": "user", "content": message})
183
+ else:
184
+ if message.strip():
185
+ history.append({"role": "user", "content": message})
186
+ for path in files:
187
+ history.append({"role": "user", "content": {"path": path}})
188
+
189
  log_debug(f"History with user message: {history}")
190
  check_format(history, "messages")
191
 
 
200
  check_format(history_no_thoughts, "messages")
201
  log_debug(f"history_no_thoughts with user message: {history_no_thoughts}")
202
 
203
+ # Build API-specific messages:
204
+ # - Convert any UI image placeholders {"path": ...} to image_url parts
205
+ # - Convert any user string content that is a valid file path to image_url parts
206
+ # - Coalesce consecutive image paths into a single image-only user message
207
+ api_messages = []
208
+ image_parts_buffer = []
209
+
210
+ def flush_image_buffer():
211
+ if len(image_parts_buffer) > 0:
212
+ api_messages.append({"role": "user", "content": list(image_parts_buffer)})
213
+ image_parts_buffer.clear()
214
+
215
+ def to_image_part(path: str):
216
+ try:
217
+ mime, _ = mimetypes.guess_type(path)
218
+ mime = mime or "application/octet-stream"
219
+ with open(path, "rb") as f:
220
+ b64 = base64.b64encode(f.read()).decode("utf-8")
221
+ data_url = f"data:{mime};base64,{b64}"
222
+ return {"type": "image_url", "image_url": {"url": data_url}}
223
+ except Exception as e:
224
+ log_error(f"Failed to load file '{path}': {e}")
225
+ return None
226
+
227
+ def normalize_msg(msg):
228
+ # Returns (role, content, as_dict) where as_dict is a message dict suitable to pass through when unmodified
229
+ if isinstance(msg, dict):
230
+ return msg.get("role"), msg.get("content"), msg
231
+ # Gradio ChatMessage-like object
232
+ role = getattr(msg, "role", None)
233
+ content = getattr(msg, "content", None)
234
+ if role is not None:
235
+ return role, content, {"role": role, "content": content}
236
+ return None, None, msg
237
+
238
+ for m in copy.deepcopy(history_no_thoughts):
239
+ role, content, as_dict = normalize_msg(m)
240
+ # Unknown structure: pass through
241
+ if role is None:
242
+ flush_image_buffer()
243
+ api_messages.append(as_dict)
244
+ continue
245
+
246
+ # Assistant messages pass through as-is
247
+ if role == "assistant":
248
+ flush_image_buffer()
249
+ api_messages.append(as_dict)
250
+ continue
251
+
252
+ # Only user messages have potential image paths to convert
253
+ if role == "user":
254
+ # Case A: {'path': ...}
255
+ if isinstance(content, dict) and isinstance(content.get("path"), str):
256
+ p = content["path"]
257
+ part = to_image_part(p) if os.path.isfile(p) else None
258
+ if part:
259
+ image_parts_buffer.append(part)
260
+ else:
261
+ flush_image_buffer()
262
+ api_messages.append({"role": "user", "content": str(content)})
263
+ continue
264
+
265
+ # Case B: string or tuple content that may be a file path
266
+ if isinstance(content, str):
267
+ if os.path.isfile(content):
268
+ part = to_image_part(content)
269
+ if part:
270
+ image_parts_buffer.append(part)
271
+ continue
272
+ # Not a file path: pass through as text
273
+ flush_image_buffer()
274
+ api_messages.append({"role": "user", "content": content})
275
+ continue
276
+ if isinstance(content, tuple):
277
+ # Common case: a single-element tuple containing a path string
278
+ tuple_items = list(content)
279
+ tmp_parts = []
280
+ text_accum = []
281
+ for item in tuple_items:
282
+ if isinstance(item, str) and os.path.isfile(item):
283
+ part = to_image_part(item)
284
+ if part:
285
+ tmp_parts.append(part)
286
+ else:
287
+ text_accum.append(item)
288
+ else:
289
+ text_accum.append(str(item))
290
+ if tmp_parts:
291
+ flush_image_buffer()
292
+ api_messages.append({"role": "user", "content": tmp_parts})
293
+ if not text_accum:
294
+ continue
295
+ if text_accum:
296
+ flush_image_buffer()
297
+ api_messages.append({"role": "user", "content": "\n".join(text_accum)})
298
+ continue
299
+
300
+ # Case C: list content
301
+ if isinstance(content, list):
302
+ # If it's already a list of parts, let it pass through
303
+ all_dicts = all(isinstance(c, dict) for c in content)
304
+ if all_dicts:
305
+ flush_image_buffer()
306
+ api_messages.append({"role": "user", "content": content})
307
+ continue
308
+ # It might be a list of strings (paths/text). Convert string paths to image parts, others to text parts
309
+ tmp_parts = []
310
+ text_accum = []
311
+
312
+ def flush_text_accum():
313
+ if text_accum:
314
+ api_messages.append({"role": "user", "content": "\n".join(text_accum)})
315
+ text_accum.clear()
316
+ for item in content:
317
+ if isinstance(item, str) and os.path.isfile(item):
318
+ part = to_image_part(item)
319
+ if part:
320
+ tmp_parts.append(part)
321
+ else:
322
+ text_accum.append(item)
323
+ else:
324
+ text_accum.append(str(item))
325
+ if tmp_parts:
326
+ flush_image_buffer()
327
+ api_messages.append({"role": "user", "content": tmp_parts})
328
+ if text_accum:
329
+ flush_text_accum()
330
+ continue
331
+
332
+ # Fallback: pass through
333
+ flush_image_buffer()
334
+ api_messages.append(as_dict)
335
+ continue
336
+
337
+ # Other roles
338
+ flush_image_buffer()
339
+ api_messages.append(as_dict)
340
+
341
+ # Flush any trailing images
342
+ flush_image_buffer()
343
+
344
+ log_debug(f"sending api_messages to model {model_name}: {api_messages}")
345
+
346
  stream = openai_client.chat.completions.create(
347
  model=model_name,
348
+ messages=api_messages,
349
  temperature=MODEL_TEMPERATURE,
350
  stream=True
351
  )
 
510
  )
511
 
512
  with gr.Row():
513
+ with gr.Column(scale=10, min_width=400, elem_classes="user-input-container"):
514
  with gr.Row():
515
+ user_input = gr.MultimodalTextbox(
516
+ interactive=True,
517
+ container=False,
518
+ file_count="multiple",
519
+ placeholder="Type your message here and press Enter or upload file...",
520
  show_label=False,
521
+ sources=["upload"]
 
522
  )
523
+
524
+ # Original text-only input
525
+ # user_input = gr.Textbox(
526
+ # show_label=False,
527
+ # placeholder="Type your message here and press Enter",
528
+ # container=False
529
+ # )
530
  with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
531
  with gr.Row():
532
  with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
533
+ send_btn = gr.Button("Send", variant="primary", elem_classes="control-button")
534
+ stop_btn = gr.Button("Stop", variant="cancel", elem_classes="control-button", visible=False)
535
  with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
536
+ clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary", elem_classes="control-button")
537
  with gr.Row():
538
  with gr.Column(min_width=400, elem_classes="opt-out-container"):
539
  with gr.Row():
styles.css CHANGED
@@ -30,6 +30,15 @@
30
  max-height: 1400px;
31
  }
32
 
 
 
 
 
 
 
 
 
 
33
  button.cancel {
34
  border: var(--button-border-width) solid var(--button-cancel-border-color);
35
  background: var(--button-cancel-background-fill);
 
30
  max-height: 1400px;
31
  }
32
 
33
+ .user-input-container .multimodal-textbox{
34
+ border: none !important;
35
+ }
36
+
37
+ /* Match the height of the modified multimodal input box on the same row */
38
+ .control-button {
39
+ height: 51px;
40
+ }
41
+
42
  button.cancel {
43
  border: var(--button-border-width) solid var(--button-cancel-border-color);
44
  background: var(--button-cancel-background-fill);