prithivMLmods commited on
Commit
9916e82
·
verified ·
1 Parent(s): 49b03ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -157
app.py CHANGED
@@ -23,15 +23,11 @@ from transformers import (
23
  )
24
  from transformers.image_utils import load_image
25
 
26
- # These imports seem to be from a custom library.
27
- # If you have 'docling_core' installed, you can uncomment them.
28
- # from docling_core.types.doc import DoclingDocument, DocTagsDocument
29
-
30
  import re
31
  import ast
32
  import html
33
 
34
- # --- Constants ---
35
  MAX_MAX_NEW_TOKENS = 5120
36
  DEFAULT_MAX_NEW_TOKENS = 3072
37
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -90,7 +86,6 @@ model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90
  torch_dtype=torch.float16
91
  ).to(device).eval()
92
 
93
-
94
  # --- Preprocessing and Helper Functions ---
95
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
96
  """Add random padding to an image based on its size."""
@@ -150,9 +145,9 @@ def format_smoldocling_output(buffer_text, images):
150
  doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
151
  doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
152
  markdown_output = doc.export_to_markdown()
153
- return markdown_output
154
  # Fallback if library is not available or tags are not present
155
- return cleaned_output
156
 
157
  # --- Core Generation Logic ---
158
  def get_model_and_processor(model_name):
@@ -170,52 +165,33 @@ def get_model_and_processor(model_name):
170
  else:
171
  return None, None
172
 
173
- def is_video_file(filepath):
174
- """Check if a file has a common video extension."""
175
- if not filepath:
176
- return False
177
- video_extensions = ['.mp4', '.mov', '.avi', '.mkv', '.webm']
178
- return any(filepath.lower().endswith(ext) for ext in video_extensions)
179
-
180
  @spaces.GPU
181
- def generate_response(
182
- media_file: str,
183
- query: str,
184
- model_name: str,
185
- max_new_tokens: int,
186
- temperature: float,
187
- top_p: float
188
- ):
189
  """Unified generation function for both image and video."""
190
- if media_file is None:
191
- yield "Please upload an image or video file first."
192
- return
193
-
194
  processor, model = get_model_and_processor(model_name)
195
  if not processor or not model:
196
- yield "Invalid model selected."
197
  return
198
 
199
- media_type = "video" if is_video_file(media_file) else "image"
200
-
201
- try:
202
- if media_type == "video":
203
- frames = downsample_video(media_file)
204
- images = [frame for frame, _ in frames]
205
- else: # image
206
- images = [Image.open(media_file)]
207
- except Exception as e:
208
- yield f"Error processing file: {e}"
209
  return
210
 
 
 
 
 
 
 
211
  if model_name == "SmolDocling-256M-preview":
212
- if "OTSL" in query or "code" in query:
213
  images = [add_random_padding(img) for img in images]
214
- if "OCR at text at" in query or "Identify element" in query or "formula" in query:
215
- query = normalize_values(query, target_max=500)
216
 
217
  messages = [
218
- {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": query}]}
219
  ]
220
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
221
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
@@ -227,153 +203,346 @@ def generate_response(
227
  "max_new_tokens": max_new_tokens,
228
  "temperature": temperature,
229
  "top_p": top_p,
 
 
230
  }
231
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
232
  thread.start()
233
 
234
  buffer = ""
235
  for new_text in streamer:
236
- buffer += new_text.replace("<|im_end|>", "")
237
- yield buffer
238
 
239
  if model_name == "SmolDocling-256M-preview":
240
- formatted_output = format_smoldocling_output(buffer, images)
241
- yield formatted_output
242
  else:
243
- yield buffer.strip()
 
244
 
245
- # --- Gradio Interface ---
 
 
 
 
246
 
247
  # --- Examples ---
248
  image_examples = [
249
- ["images/0.png", "Reconstruct the doc [table] as it is."],
250
- ["images/8.png", "Describe the image!"],
251
- ["images/2.jpg", "OCR the image"],
252
- ["images/1.png", "Convert this page to docling"],
253
- ["images/3.png", "Convert this page to docling"],
254
- ["images/4.png", "Convert chart to OTSL."],
255
- ["images/5.jpg", "Convert code to text"],
256
- ["images/6.jpg", "Convert this table to OTSL."],
257
- ["images/7.jpg", "Convert formula to latex."],
258
  ]
 
259
  video_examples = [
260
- ["videos/1.mp4", "Explain the video in detail."],
261
- ["videos/2.mp4", "Explain the video in detail."]
262
  ]
263
- all_examples = image_examples + video_examples
264
-
265
 
266
- # --- UI Styling and Helper Functions ---
267
  css = """
268
- body, .gradio-container { font-family: 'Inter', sans-serif; }
269
- .main-container { padding: 20px; }
270
- .sidebar { background-color: #F7F7F7; border-right: 1px solid #E0E0E0; padding: 15px; border-radius: 15px; }
271
- .chat-window { min-height: 60vh; border: 1px solid #E0E0E0; border-radius: 15px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.05); }
272
- .input-bar { padding: 10px; border-radius: 15px; background-color: #FFFFFF; border: 1px solid #E0E0E0; margin-top: 20px;}
273
- .submit-button { background-color: #007AFF !important; color: white !important; font-weight: bold !important; }
274
- .media-display {text-align: center; background-color: #F0F0F0; border-radius: 10px; padding: 10px; margin-bottom: 20px;}
275
- .media-display img, .media-display video {max-height: 400px; margin: auto;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  """
277
 
278
- def handle_file_upload(file):
279
- if file is None:
280
- return None, gr.update(visible=False), gr.update(visible=False)
281
- if is_video_file(file.name):
282
- return file.name, gr.update(visible=False), gr.update(value=file.name, visible=True)
283
- else:
284
- return file.name, gr.update(value=file.name, visible=True), gr.update(visible=False)
285
-
286
- def handle_example_click(file_path, query):
287
- if is_video_file(file_path):
288
- # Update state, hide image, show video, update query
289
- return file_path, gr.update(visible=False), gr.update(value=file_path, visible=True), query
290
- else:
291
- # Update state, show image, hide video, update query
292
- return file_path, gr.update(value=file_path, visible=True), gr.update(visible=False), query
293
-
294
- def clear_all():
295
- return None, gr.update(visible=False), gr.update(visible=False), "### Output will be shown here", ""
296
-
297
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
298
- media_file_path = gr.State(None)
299
 
300
- with gr.Row(elem_classes="main-container"):
301
- # --- Sidebar ---
302
- with gr.Column(scale=1, elem_classes="sidebar"):
303
- gr.Markdown("### OCR Conversations")
304
- add_conv_btn = gr.Button("+ Add Conversation")
305
- gr.Markdown("---")
306
- gr.Markdown("#### Advanced Options")
307
- with gr.Accordion("⚙️ Generation Settings", open=False):
308
- max_new_tokens = gr.Slider(
309
- label="Max New Tokens",
310
- minimum=256,
311
- maximum=MAX_MAX_NEW_TOKENS,
312
- step=64,
313
- value=DEFAULT_MAX_NEW_TOKENS,
314
- )
315
- temperature = gr.Slider(
316
- label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.6
317
- )
318
- top_p = gr.Slider(
319
- label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9
320
- )
321
-
322
- # --- Main Content Panel ---
323
- with gr.Column(scale=4):
324
- gr.Markdown("# Multimodal OCR")
325
 
326
- with gr.Column(elem_classes="media-display"):
327
- image_display = gr.Image(type="filepath", label="Image Preview", visible=False)
328
- video_display = gr.Video(label="Video Preview", visible=False)
329
- gr.Markdown("Upload an image or video to begin.")
330
-
331
- output_display = gr.Markdown(elem_classes="chat-window", value="### Output will be shown here")
332
-
333
- # --- Input Bar ---
334
- with gr.Row(elem_classes="input-bar"): # Removed vertical=False
335
- upload_btn = gr.UploadButton("📁 Add Files", file_types=["image", "video"])
336
- model_dropdown = gr.Dropdown(
337
- choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
338
- label="Select Model",
339
- value="Nanonets-OCR-s"
340
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  query_input = gr.Textbox(
342
- placeholder="Enter your query here...",
343
- show_label=False,
344
- scale=4,
345
  )
346
- submit_btn = gr.Button("▶", elem_classes="submit-button")
347
 
348
- # --- Examples defined after all components exist ---
349
- gr.Examples(
350
- examples=all_examples,
351
- inputs=[media_file_path, query_input],
352
- outputs=[media_file_path, image_display, video_display, query_input],
353
- fn=handle_example_click,
354
- label="Examples (Click to run)",
355
- cache_examples=True
356
- )
357
-
 
 
 
 
 
 
358
 
359
  # --- Event Handlers ---
360
- upload_btn.upload(
 
 
 
 
 
 
 
 
 
 
361
  fn=handle_file_upload,
362
- inputs=[upload_btn],
363
- outputs=[media_file_path, image_display, video_display]
364
  )
365
-
366
- submit_btn.click(
367
- fn=generate_response,
368
- inputs=[media_file_path, query_input, model_dropdown, max_new_tokens, temperature, top_p],
369
- outputs=[output_display]
 
 
 
 
370
  )
371
 
372
- add_conv_btn.click(
373
- fn=clear_all,
374
- outputs=[media_file_path, image_display, video_display, output_display, query_input]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  )
376
 
377
-
378
  if __name__ == "__main__":
379
  demo.queue(max_size=50).launch(share=True, show_error=True)
 
23
  )
24
  from transformers.image_utils import load_image
25
 
 
 
 
 
26
  import re
27
  import ast
28
  import html
29
 
30
+ # Constants for text generation
31
  MAX_MAX_NEW_TOKENS = 5120
32
  DEFAULT_MAX_NEW_TOKENS = 3072
33
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
86
  torch_dtype=torch.float16
87
  ).to(device).eval()
88
 
 
89
  # --- Preprocessing and Helper Functions ---
90
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
91
  """Add random padding to an image based on its size."""
 
145
  doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
146
  doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
147
  markdown_output = doc.export_to_markdown()
148
+ return buffer_text, markdown_output
149
  # Fallback if library is not available or tags are not present
150
+ return buffer_text, cleaned_output
151
 
152
  # --- Core Generation Logic ---
153
  def get_model_and_processor(model_name):
 
165
  else:
166
  return None, None
167
 
 
 
 
 
 
 
 
168
  @spaces.GPU
169
+ def generate_response(model_name: str, text: str, media_input, media_type: str,
170
+ max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
 
 
 
 
 
 
171
  """Unified generation function for both image and video."""
 
 
 
 
172
  processor, model = get_model_and_processor(model_name)
173
  if not processor or not model:
174
+ yield "Invalid model selected.", "Invalid model selected."
175
  return
176
 
177
+ if media_input is None:
178
+ yield f"Please upload a {media_type}.", f"Please upload a {media_type}."
 
 
 
 
 
 
 
 
179
  return
180
 
181
+ if media_type == "video":
182
+ frames = downsample_video(media_input)
183
+ images = [frame for frame, _ in frames]
184
+ else: # image
185
+ images = [media_input]
186
+
187
  if model_name == "SmolDocling-256M-preview":
188
+ if "OTSL" in text or "code" in text:
189
  images = [add_random_padding(img) for img in images]
190
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
191
+ text = normalize_values(text, target_max=500)
192
 
193
  messages = [
194
+ {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": text}]}
195
  ]
196
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
197
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
 
203
  "max_new_tokens": max_new_tokens,
204
  "temperature": temperature,
205
  "top_p": top_p,
206
+ "top_k": top_k,
207
+ "repetition_penalty": repetition_penalty,
208
  }
209
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
210
  thread.start()
211
 
212
  buffer = ""
213
  for new_text in streamer:
214
+ buffer += new_text.replace("", "")
215
+ yield buffer, buffer
216
 
217
  if model_name == "SmolDocling-256M-preview":
218
+ raw_output, formatted_output = format_smoldocling_output(buffer, images)
219
+ yield raw_output, formatted_output
220
  else:
221
+ # For other models, the formatted output is just the cleaned buffer
222
+ yield buffer, buffer.strip()
223
 
224
def generate_image_wrapper(text, media_input, model_name, max_new_tokens,
                           temperature, top_p, top_k, repetition_penalty):
    """Stream a response for a single image.

    Parameters arrive in the order used by the UI caller (query text, media,
    model name, then the sampling settings) and are re-ordered here to match
    ``generate_response(model_name, text, media_input, media_type, ...)``.

    The previous ``generate_response(*args, media_type="image")`` form could
    never work: ``media_type`` is the fourth positional parameter, so passing
    it by keyword after more than three positional arguments raises
    ``TypeError`` (multiple values for ``media_type``).

    Yields:
        Whatever ``generate_response`` yields (pairs of raw and formatted text).
    """
    yield from generate_response(
        model_name,
        text,
        media_input,
        "image",
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    )
226
+
227
def generate_video_wrapper(text, media_input, model_name, max_new_tokens,
                           temperature, top_p, top_k, repetition_penalty):
    """Stream a response for a video file.

    Parameters arrive in the order used by the UI caller (query text, media,
    model name, then the sampling settings) and are re-ordered here to match
    ``generate_response(model_name, text, media_input, media_type, ...)``.

    The previous ``generate_response(*args, media_type="video")`` form could
    never work: ``media_type`` is the fourth positional parameter, so passing
    it by keyword after more than three positional arguments raises
    ``TypeError`` (multiple values for ``media_type``).

    Yields:
        Whatever ``generate_response`` yields (pairs of raw and formatted text).
    """
    yield from generate_response(
        model_name,
        text,
        media_input,
        "video",
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    )
229
 
230
  # --- Examples ---
231
  image_examples = [
232
+ ["Reconstruct the doc [table] as it is.", "images/0.png"],
233
+ ["Describe the image!", "images/8.png"],
234
+ ["OCR the image", "images/2.jpg"],
235
+ ["Convert this page to docling", "images/1.png"],
236
+ ["Convert this page to docling", "images/3.png"],
237
+ ["Convert chart to OTSL.", "images/4.png"],
238
+ ["Convert code to text", "images/5.jpg"],
239
+ ["Convert this table to OTSL.", "images/6.jpg"],
240
+ ["Convert formula to latex.", "images/7.jpg"],
241
  ]
242
+
243
  video_examples = [
244
+ ["Explain the video in detail.", "videos/1.mp4"],
245
+ ["Explain the video in detail.", "videos/2.mp4"]
246
  ]
 
 
247
 
248
+ # --- Custom CSS for the new UI ---
249
  css = """
250
+ /* Left sidebar styles */
251
+ .sidebar {
252
+ background-color: #f8f9fa;
253
+ border-right: 1px solid #e9ecef;
254
+ padding: 20px;
255
+ height: 100vh;
256
+ }
257
+
258
+ /* Main content area */
259
+ .content-area {
260
+ padding: 20px;
261
+ }
262
+
263
+ /* Document grid */
264
+ .doc-grid {
265
+ display: grid;
266
+ grid-template-columns: repeat(5, 1fr);
267
+ gap: 10px;
268
+ margin: 20px 0;
269
+ }
270
+
271
+ .doc-item {
272
+ border: 1px solid #dee2e6;
273
+ border-radius: 8px;
274
+ padding: 10px;
275
+ text-align: center;
276
+ height: 120px;
277
+ background-color: #f8f9fa;
278
+ cursor: pointer;
279
+ transition: all 0.2s ease;
280
+ }
281
+
282
+ .doc-item:hover {
283
+ border-color: #007bff;
284
+ background-color: #e9f0ff;
285
+ }
286
+
287
+ /* Upload and controls area */
288
+ .upload-controls {
289
+ display: flex;
290
+ align-items: center;
291
+ gap: 10px;
292
+ margin: 20px 0;
293
+ padding: 15px;
294
+ border: 1px solid #e9ecef;
295
+ border-radius: 8px;
296
+ }
297
+
298
+ .file-upload {
299
+ flex: 1;
300
+ }
301
+
302
+ .model-dropdown {
303
+ width: 200px;
304
+ }
305
+
306
+ .submit-btn {
307
+ background-color: #007bff;
308
+ color: white;
309
+ border: none;
310
+ border-radius: 4px;
311
+ padding: 10px 20px;
312
+ font-size: 1.2rem;
313
+ cursor: pointer;
314
+ transition: background-color 0.2s;
315
+ }
316
+
317
+ .submit-btn:hover {
318
+ background-color: #0069d9;
319
+ }
320
+
321
+ /* Output area */
322
+ .output-area {
323
+ margin-top: 20px;
324
+ }
325
+
326
+ /* Add conversation button */
327
+ .add-conv-btn {
328
+ background-color: #28a745;
329
+ color: white;
330
+ border: none;
331
+ padding: 8px 15px;
332
+ border-radius: 4px;
333
+ cursor: pointer;
334
+ }
335
+
336
+ .add-conv-btn:hover {
337
+ background-color: #218838;
338
+ }
339
+
340
+ /* Examples section */
341
+ .examples-section {
342
+ margin-top: 20px;
343
+ }
344
+
345
+ /* Header styles */
346
+ .header {
347
+ margin-bottom: 15px;
348
+ }
349
+
350
+ /* Media upload icon styling */
351
+ .upload-icon {
352
+ font-size: 1.5rem;
353
+ color: #6c757d;
354
+ margin-right: 10px;
355
+ }
356
+
357
+ /* Document icon styling */
358
+ .doc-icon {
359
+ font-size: 2rem;
360
+ color: #6c757d;
361
+ margin-bottom: 5px;
362
+ }
363
+
364
+ /* Query input */
365
+ .query-input {
366
+ margin: 15px 0;
367
+ }
368
+
369
+ /* Model dropdown styling */
370
+ .model-dropdown .select {
371
+ padding: 8px 12px;
372
+ border: 1px solid #ced4da;
373
+ border-radius: 4px;
374
+ }
375
+
376
+ /* Output styling */
377
+ .output-text {
378
+ border: 1px solid #ced4da;
379
+ border-radius: 4px;
380
+ padding: 10px;
381
+ min-height: 150px;
382
+ }
383
+
384
+ /* Add some space between elements */
385
+ .gr-box {
386
+ margin-bottom: 15px;
387
+ }
388
  """
389
 
390
+ # --- Gradio Interface ---
391
with gr.Blocks(css=css) as demo:
    gr.Markdown("# **[Multimodal OCR2](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")

    # The examples are laid out *above* the inputs they fill, so the two
    # target components are created up-front with render=False and placed
    # later with .render(). Previously gr.Examples referenced State objects
    # that were only created further down, which raised NameError at import.
    file_upload = gr.File(
        label="Upload files (image/video)",
        file_types=["image", "video"],
        elem_classes="file-upload",
        render=False,
    )
    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="Describe the image, extract text, convert to markdown...",
        elem_classes="query-input",
        render=False,
    )

    with gr.Row():
        # Left sidebar - OCR section
        with gr.Column(scale=1, min_width=250, elem_classes="sidebar"):
            gr.Markdown("## OCR")
            # NOTE(review): no event handler is attached to this button yet.
            add_conv_btn = gr.Button("+ Add Conv", elem_classes="add-conv-btn")

            # Document grid
            gr.Markdown("### Documents")
            with gr.Group(elem_classes="doc-grid"):
                for i in range(5):
                    with gr.Column():
                        gr.Markdown(f'<div class="doc-item"><div class="doc-icon">📄</div>Doc {i+1}</div>')

        # Main content area
        with gr.Column(scale=3, elem_classes="content-area"):
            # Document processing section
            with gr.Group():
                gr.Markdown("## Multimodal OCR2")

                # Document grid (5 document thumbnails as shown in the sketch)
                with gr.Row(elem_classes="doc-grid"):
                    for i in range(5):
                        with gr.Column():
                            gr.Image(
                                value=None,
                                label=f"Document {i+1}",
                                height=120,
                                show_label=False,
                                container=False,
                                elem_classes="doc-item",
                            )

                # Examples section: each example row is [query text, file path],
                # so it fills the query box and the file picker directly.
                gr.Markdown("### Examples")
                with gr.Row():
                    with gr.Column():
                        gr.Examples(
                            examples=image_examples,
                            inputs=[query_input, file_upload],
                            label="Image Examples",
                        )
                    with gr.Column():
                        gr.Examples(
                            examples=video_examples,
                            inputs=[query_input, file_upload],
                            label="Video Examples",
                        )

                # File upload and controls
                with gr.Group(elem_classes="upload-controls"):
                    # File upload area
                    with gr.Column(elem_classes="file-upload"):
                        file_upload.render()

                    # Model dropdown
                    model_dropdown = gr.Dropdown(
                        choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
                        value="Nanonets-OCR-s",
                        label="Select Model",
                        elem_classes="model-dropdown",
                    )

                    # Submit button
                    submit_btn = gr.Button("→", size="lg", elem_classes="submit-btn")

                # Advanced options (hidden by default)
                with gr.Accordion("Advanced Options", open=False):
                    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                    top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                    top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)

                # Query input
                query_input.render()

                # Output area
                with gr.Group(elem_classes="output-area"):
                    gr.Markdown("### Output")
                    raw_output = gr.Textbox(
                        label="Result",
                        interactive=False,
                        lines=10,
                        elem_classes="output-text",
                    )

    # State filled by handle_file_upload and read by generate_wrapper.
    image_upload = gr.State(None)
    video_upload = gr.State(None)
    media_type = gr.State("image")

    # --- Event Handlers ---
    def handle_file_upload(file):
        """Classify the uploaded file and return (media_type, image, video_path).

        Accepts either a plain path string or an object exposing ``.name``
        (the value type depends on the Gradio version / gr.File ``type``).
        Unrecognised extensions and a cleared upload both reset to
        ("image", None, None).
        """
        if file is None:
            return "image", None, None
        file_path = getattr(file, "name", file)
        lowered = str(file_path).lower()
        if lowered.endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp')):
            return "image", Image.open(file_path), None
        if lowered.endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
            return "video", None, file_path
        return "image", None, None

    file_upload.change(
        fn=handle_file_upload,
        inputs=[file_upload],
        outputs=[media_type, image_upload, video_upload]
    )

    def handle_model_selection(model_name):
        # This function could be used to update the UI based on model selection
        return f"Using {model_name}"

    model_dropdown.change(
        fn=handle_model_selection,
        inputs=[model_dropdown],
        outputs=[]
    )

    def generate_wrapper(text, img, vid, model, max_tokens, temp, top_p, top_k, rep_penalty, m_type):
        """Pick the uploaded media matching ``m_type`` and stream the model output.

        generate_response expects (model_name, text, media_input, media_type,
        max_new_tokens, temperature, top_p, top_k, repetition_penalty); the
        arguments are passed positionally in exactly that order. Only the
        formatted half of each (raw, formatted) pair is shown because a single
        textbox displays the result.
        """
        if m_type == "image" and img is not None:
            media = img
        elif m_type == "video" and vid is not None:
            media = vid
        else:
            yield "Please upload a valid file"
            return

        for _raw, formatted in generate_response(
            model, text, media, m_type, max_tokens, temp, top_p, top_k, rep_penalty
        ):
            yield formatted

    submit_btn.click(
        fn=generate_wrapper,
        inputs=[
            query_input,
            image_upload,
            video_upload,
            model_dropdown,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            media_type
        ],
        outputs=[raw_output]
    )
546
 
 
547
  if __name__ == "__main__":
548
  demo.queue(max_size=50).launch(share=True, show_error=True)