prithivMLmods committed on
Commit 063e299 · verified · 1 Parent(s): be8b851

Update app.py

Files changed (1)
  1. app.py +370 -179
app.py CHANGED
@@ -11,13 +11,13 @@ from threading import Thread
from io import BytesIO
import uuid
import tempfile

import gradio as gr
- import requests
import torch
from PIL import Image
- import fitz
- import numpy as np


from transformers import (
@@ -26,33 +26,29 @@ from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
-     AutoTokenizer,
)

- from transformers.image_utils import load_image as hf_load_image
-
- from reportlab.lib.pagesizes import A4
- from reportlab.lib.styles import getSampleStyleSheet
- from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
- from reportlab.lib.units import inch
-
# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
- print("cuda available:", torch.cuda.is_available())
- print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
-     print("current device:", torch.cuda.current_device())
-     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
-
print("Using device:", device)

# --- Model Loading ---
# Load Camel-Doc-OCR-062825
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -60,8 +56,10 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# MinerU2.5-2509
MODEL_ID_T = "opendatalab/MinerU2.5-2509-1.2B"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -69,8 +67,11 @@ model_t = Qwen2VLForConditionalGeneration.from_pretrained(
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load Video-MTR
MODEL_ID_S = "Phoebe13/Video-MTR"
processor_s = AutoProcessor.from_pretrained(MODEL_ID_S, trust_remote_code=True)
model_s = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -78,8 +79,10 @@ model_s = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

- # moondream3
MODEL_ID_MD3 = "moondream/moondream3-preview"
model_md3 = AutoModelForCausalLM.from_pretrained(
    MODEL_ID_MD3,
@@ -87,79 +90,228 @@ model_md3 = AutoModelForCausalLM.from_pretrained(
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)
- # FIX: Added trust_remote_code=True to resolve the loading error
- tokenizer_md3 = AutoTokenizer.from_pretrained(MODEL_ID_MD3, trust_remote_code=True)


- # --- PDF Generation and Preview Utility Function ---
- def generate_and_preview_pdf(image: Image.Image, text_content: str, font_size: int, line_spacing: float, alignment: str, image_size: str):
-     """
-     Generates a PDF, saves it, and then creates image previews of its pages.
-     Returns the path to the PDF and a list of paths to the preview images.
-     """
-     if image is None or not text_content or not text_content.strip():
-         raise gr.Error("Cannot generate PDF. Image or text content is missing.")
-
-     # --- 1. Generate the PDF ---
-     temp_dir = tempfile.gettempdir()
-     pdf_filename = os.path.join(temp_dir, f"output_{uuid.uuid4()}.pdf")
-     doc = SimpleDocTemplate(
-         pdf_filename,
-         pagesize=A4,
-         rightMargin=inch, leftMargin=inch,
-         topMargin=inch, bottomMargin=inch
    )
-     styles = getSampleStyleSheet()
-     style_normal = styles["Normal"]
-     style_normal.fontSize = int(font_size)
-     style_normal.leading = int(font_size) * line_spacing
-     style_normal.alignment = {"Left": 0, "Center": 1, "Right": 2, "Justified": 4}[alignment]

-     story = []

-     img_buffer = BytesIO()
-     image.save(img_buffer, format='PNG')
-     img_buffer.seek(0)

-     page_width, _ = A4
-     available_width = page_width - 2 * inch
-     image_widths = {
-         "Small": available_width * 0.3,
-         "Medium": available_width * 0.6,
-         "Large": available_width * 0.9,
-     }
-     img_width = image_widths[image_size]
-     img = RLImage(img_buffer, width=img_width, height=image.height * (img_width / image.width))
-     story.append(img)
-     story.append(Spacer(1, 12))

-     cleaned_text = re.sub(r'#+\s*', '', text_content).replace("*", "")
-     text_paragraphs = cleaned_text.split('\n')

-     for para in text_paragraphs:
-         if para.strip():
-             story.append(Paragraph(para, style_normal))

-     doc.build(story)

-     # --- 2. Render PDF pages as images for preview ---
-     preview_images = []
-     try:
-         pdf_doc = fitz.open(pdf_filename)
-         for page_num in range(len(pdf_doc)):
-             page = pdf_doc.load_page(page_num)
-             pix = page.get_pixmap(dpi=150)
-             preview_img_path = os.path.join(temp_dir, f"preview_{uuid.uuid4()}_p{page_num}.png")
-             pix.save(preview_img_path)
-             preview_images.append(preview_img_path)
-         pdf_doc.close()
-     except Exception as e:
-         print(f"Error generating PDF preview: {e}")
-
-     return pdf_filename, preview_images


- # --- Core Application Logic ---
@spaces.GPU
def process_document_stream(
    model_name: str,
@@ -172,39 +324,27 @@ def process_document_stream(
    repetition_penalty: float
):
    """
-     Main generator function that handles model inference tasks with advanced generation parameters.
    """
    if image is None:
-         yield "Please upload an image.", ""
        return
    if not prompt_input or not prompt_input.strip():
-         yield "Please enter a prompt.", ""
        return

-     # --- Special Handling for Moondream3 ---
-     if model_name == "Moondream3":
-         # Moondream3 uses a different prompt structure and doesn't stream by default in this implementation
-         prompt_full = f"<image>\n\nQuestion: {prompt_input}\n\nAnswer:"
-         answer = model_md3.answer_question(
-             model_md3.encode_image(image),
-             prompt_full,
-             tokenizer=tokenizer_md3
-         )
-         yield answer, answer
-         return
-
-     processor = None
-     model = None
-
-     # --- Generic Handling for all other models ---
-     if model_name == "Camel-Doc-OCR-062825": processor, model = processor_m, model_m
-     elif model_name == "MinerU2.5-2509-1.2B": processor, model = processor_t, model_t
-     elif model_name == "Video-MTR": processor, model = processor_s, model_s
    else:
-         yield "Invalid model selected.", ""
        return

-     messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt_input}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)

@@ -218,7 +358,7 @@ def process_document_stream(
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
-         "do_sample": True
    }

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -227,12 +367,10 @@ def process_document_stream(
    buffer = ""
    for new_text in streamer:
        buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
-         yield buffer, buffer
-
-     yield buffer, buffer
-

# --- Gradio UI Definition ---
def create_gradio_interface():
@@ -241,89 +379,142 @@ def create_gradio_interface():
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
-     #gallery { min-height: 400px; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
-         gr.HTML("""
-         <div class="title" style="text-align: center">
-             <h1>Multimodal VLM v1.0</h1>
-             <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
-                 Multimodal VLM for Image Content Extraction and Understanding
-             </p>
-         </div>
-         """)
-
-         with gr.Row():
-             # Left Column (Inputs)
-             with gr.Column(scale=1):
-                 model_choice = gr.Dropdown(
-                     choices=["Camel-Doc-OCR-062825", "MinerU2.5-2509-1.2B", "Video-MTR", "Moondream3"],
-                     label="Select Model", value="Camel-Doc-OCR-062825"
-                 )
-
-                 prompt_input = gr.Textbox(label="Query Input", placeholder="✦︎ Enter the prompt")
-                 image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])

-                 with gr.Accordion("Advanced Settings", open=False):
-                     max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
-                     temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                     top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                     top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                     repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
-                 gr.Markdown("### PDF Export Settings")
-                 font_size = gr.Dropdown(choices=["8", "10", "12", "14", "16", "18"], value="12", label="Font Size")
-                 line_spacing = gr.Dropdown(choices=[1.0, 1.15, 1.5, 2.0], value=1.15, label="Line Spacing")
-                 alignment = gr.Dropdown(choices=["Left", "Center", "Right", "Justified"], value="Justified", label="Text Alignment")
-                 image_size = gr.Dropdown(choices=["Small", "Medium", "Large"], value="Medium", label="Image Size in PDF")
-
-                 process_btn = gr.Button("🚀 Process Image", variant="primary", elem_classes=["process-button"], size="lg")
-                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
-
-             # Right Column (Outputs)
-             with gr.Column(scale=2):
-                 with gr.Tabs() as tabs:
-                     with gr.Tab("📝 Extracted Content"):
-                         raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=15, show_copy_button=True)
                        with gr.Row():
-                             examples = gr.Examples(
-                                 examples=["examples/1.png", "examples/2.png", "examples/3.png",
-                                           "examples/4.png", "examples/5.png", "examples/6.png"],
-                                 inputs=image_input, label="Examples"
-                             )
-                         gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/Multimodal-VLM-v1.0/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")

-                     with gr.Tab("📰 README.md"):
-                         with gr.Accordion("(Result.md)", open=True):
-                             markdown_output = gr.Markdown()
-
-                     with gr.Tab("📋 PDF Preview"):
-                         generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
-                         pdf_output_file = gr.File(label="Download Generated PDF", interactive=False)
-                         pdf_preview_gallery = gr.Gallery(label="PDF Page Preview", show_label=True, elem_id="gallery", columns=2, object_fit="contain", height="auto")
-
-         # Event Handlers
-         def clear_all_outputs():
-             return None, "", "Raw output will appear here.", "", None, None
-
        process_btn.click(
            fn=process_document_stream,
-             inputs=[model_choice, image_input, prompt_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-             outputs=[raw_output_stream, markdown_output]
        )

-         generate_pdf_btn.click(
-             fn=generate_and_preview_pdf,
-             inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],
-             outputs=[pdf_output_file, pdf_preview_gallery]
        )
-
-         clear_btn.click(
-             clear_all_outputs,
-             outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery]
        )

    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
-     demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)

from io import BytesIO
import uuid
import tempfile
+ import cv2

import gradio as gr
+ import numpy as np
import torch
from PIL import Image
+ import supervision as sv


from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ print("--- System Information ---")
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
+ print("CUDA available:", torch.cuda.is_available())
+ print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
+     print("Current device:", torch.cuda.current_device())
+     print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
print("Using device:", device)
+ print("--------------------------")
+

# --- Model Loading ---
+
# Load Camel-Doc-OCR-062825
+ print("Loading Camel-Doc-OCR-062825...")
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
+ print("Camel-Doc-OCR-062825 loaded.")

# MinerU2.5-2509
+ print("Loading MinerU2.5-2509...")
MODEL_ID_T = "opendatalab/MinerU2.5-2509-1.2B"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2VLForConditionalGeneration.from_pretrained(
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
+ print("MinerU2.5-2509 loaded.")
+

# Load Video-MTR
+ print("Loading Video-MTR...")
MODEL_ID_S = "Phoebe13/Video-MTR"
processor_s = AutoProcessor.from_pretrained(MODEL_ID_S, trust_remote_code=True)
model_s = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
+ print("Video-MTR loaded.")

+ # Load moondream3
+ print("Loading moondream3-preview...")
MODEL_ID_MD3 = "moondream/moondream3-preview"
model_md3 = AutoModelForCausalLM.from_pretrained(
    MODEL_ID_MD3,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)
+ model_md3.compile()
+ print("moondream3-preview loaded and compiled.")

+ # --- Moondream3 Utility Functions ---
+
+ def create_annotated_image(image, detection_result, object_name="Object"):
+     if not isinstance(detection_result, dict) or "objects" not in detection_result:
+         return image
+
+     original_width, original_height = image.size
+     annotated_image = np.array(image.convert("RGB"))
+
+     bboxes = []
+     labels = []
+
+     for i, obj in enumerate(detection_result["objects"]):
+         x_min = int(obj["x_min"] * original_width)
+         y_min = int(obj["y_min"] * original_height)
+         x_max = int(obj["x_max"] * original_width)
+         y_max = int(obj["y_max"] * original_height)
+
+         x_min = max(0, min(x_min, original_width))
+         y_min = max(0, min(y_min, original_height))
+         x_max = max(0, min(x_max, original_width))
+         y_max = max(0, min(y_max, original_height))
+
+         if x_max > x_min and y_max > y_min:
+             bboxes.append([x_min, y_min, x_max, y_max])
+             labels.append(f"{object_name} {i+1}")
+
+     if not bboxes:
+         return image
+
+     detections = sv.Detections(
+         xyxy=np.array(bboxes, dtype=np.float32),
+         class_id=np.arange(len(bboxes))
+     )
+
+     bounding_box_annotator = sv.BoxAnnotator(
+         thickness=3,
+         color_lookup=sv.ColorLookup.INDEX
    )
+     label_annotator = sv.LabelAnnotator(
+         text_thickness=2,
+         text_scale=0.6,
+         color_lookup=sv.ColorLookup.INDEX
+     )
+
+     annotated_image = bounding_box_annotator.annotate(
+         scene=annotated_image, detections=detections
+     )
+     annotated_image = label_annotator.annotate(
+         scene=annotated_image, detections=detections, labels=labels
+     )
+
+     return Image.fromarray(annotated_image)

+ @spaces.GPU()
+ def process_video_with_tracking(video_path, prompt, detection_interval=3):
+     cap = cv2.VideoCapture(video_path)
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

+     byte_tracker = sv.ByteTrack()
+
+     temp_dir = tempfile.mkdtemp()
+     output_path = os.path.join(temp_dir, "tracked_video.mp4")
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     frame_count = 0
+     detection_count = 0
+
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             run_detection = (frame_count % detection_interval == 0)
+             detections = sv.Detections.empty()
+
+             if run_detection:
+                 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 pil_image = Image.fromarray(frame_rgb)
+
+                 result = model_md3.detect(pil_image, prompt)
+                 detection_count += 1
+
+                 if "objects" in result and result["objects"]:
+                     bboxes = []
+                     confidences = []
+
+                     for obj in result["objects"]:
+                         x_min = max(0.0, min(1.0, obj["x_min"])) * width
+                         y_min = max(0.0, min(1.0, obj["y_min"])) * height
+                         x_max = max(0.0, min(1.0, obj["x_max"])) * width
+                         y_max = max(0.0, min(1.0, obj["y_max"])) * height
+
+                         if x_max > x_min and y_max > y_min:
+                             bboxes.append([x_min, y_min, x_max, y_max])
+                             confidences.append(0.8)
+
+                     if bboxes:
+                         detections = sv.Detections(
+                             xyxy=np.array(bboxes, dtype=np.float32),
+                             confidence=np.array(confidences, dtype=np.float32),
+                             class_id=np.zeros(len(bboxes), dtype=int)
+                         )
+
+             detections = byte_tracker.update_with_detections(detections)

+             if len(detections) > 0:
+                 box_annotator = sv.BoxAnnotator(thickness=3, color_lookup=sv.ColorLookup.TRACK)
+                 label_annotator = sv.LabelAnnotator(text_scale=0.6, text_thickness=2, color_lookup=sv.ColorLookup.TRACK)
+
+                 labels = [f"{prompt} ID: {tracker_id}" for tracker_id in detections.tracker_id]
+
+                 frame = box_annotator.annotate(scene=frame, detections=detections)
+                 frame = label_annotator.annotate(scene=frame, detections=detections, labels=labels)
+
+             out.write(frame)
+             frame_count += 1
+
+             if frame_count % 30 == 0:
+                 progress = (frame_count / total_frames) * 100
+                 print(f"Processing: {progress:.1f}% ({frame_count}/{total_frames}) - Detections: {detection_count}")
+
+     finally:
+         cap.release()
+         out.release()

+     summary = f"""Video processing complete:
+ - Total frames processed: {frame_count}
+ - Detection runs: {detection_count} (every {detection_interval} frames)
+ - Objects tracked: {prompt}
+ - Processing speed: ~{detection_count/frame_count*100:.1f}% detection rate for optimization"""
+
+     return output_path, summary

+ def create_point_annotated_image(image, point_result):
+     if not isinstance(point_result, dict) or "points" not in point_result:
+         return image
+
+     original_width, original_height = image.size
+     annotated_image = np.array(image.convert("RGB"))
+
+     points = []
+     for point in point_result["points"]:
+         x = int(point["x"] * original_width)
+         y = int(point["y"] * original_height)
+         points.append([x, y])
+
+     if points:
+         points_array = np.array(points).reshape(1, -1, 2)
+         key_points = sv.KeyPoints(xy=points_array)
+         vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
+         annotated_image = vertex_annotator.annotate(
+             scene=annotated_image, key_points=key_points
+         )
+
+     return Image.fromarray(annotated_image)

+ @spaces.GPU()
+ def detect_objects_md3(image, prompt, task_type, max_objects):
+     STANDARD_SIZE = (1024, 1024)
+     if image is None:
+         raise gr.Error("Please upload an image.")
+     image.thumbnail(STANDARD_SIZE)
+
+     t0 = time.perf_counter()
+
+     if task_type == "Object Detection":
+         settings = {"max_objects": max_objects} if max_objects > 0 else {}
+         result = model_md3.detect(image, prompt, settings=settings)
+         annotated_image = create_annotated_image(image, result, prompt)
+     elif task_type == "Point Detection":
+         result = model_md3.point(image, prompt)
+         annotated_image = create_point_annotated_image(image, result)
+     elif task_type == "Caption":
+         result = model_md3.caption(image, length="normal")
+         annotated_image = image
+     else:
+         result = model_md3.query(image=image, question=prompt, reasoning=True)
+         annotated_image = image
+
+     elapsed_ms = (time.perf_counter() - t0) * 1_000
+
+     if isinstance(result, dict):
+         if "objects" in result:
+             output_text = f"Found {len(result['objects'])} objects:\n"
+             for i, obj in enumerate(result['objects'], 1):
+                 output_text += f"\n{i}. Bounding box: ({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
+         elif "points" in result:
+             output_text = f"Found {len(result['points'])} points:\n"
+             for i, point in enumerate(result['points'], 1):
+                 output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
+         elif "caption" in result:
+             output_text = result['caption']
+         elif "answer" in result:
+             output_text = f"Reasoning: {result.get('reasoning', 'N/A')}\n\nAnswer: {result['answer']}"
+         else:
+             output_text = json.dumps(result, indent=2)
+     else:
+         output_text = str(result)
+
+     timing_text = f"Inference time: {elapsed_ms:.0f} ms"
+
+     return annotated_image, output_text, timing_text
+
+ def process_video_md3(video_file, prompt, detection_interval):
+     if video_file is None:
+         return None, "Please upload a video file"
+     output_path, summary = process_video_with_tracking(video_file, prompt, detection_interval)
+     return output_path, summary

315
  @spaces.GPU
316
  def process_document_stream(
317
  model_name: str,
 
324
  repetition_penalty: float
325
  ):
326
  """
327
+ Main generator function for models other than Moondream3.
328
  """
329
  if image is None:
330
+ yield "Please upload an image."
331
  return
332
  if not prompt_input or not prompt_input.strip():
333
+ yield "Please enter a prompt."
334
  return
335
 
336
+ # Select processor and model based on dropdown choice
337
+ if model_name == "Camel-Doc-OCR-062825 (OCR)":
338
+ processor, model = processor_m, model_m
339
+ elif model_name == "MinerU2.5-2509 (General)":
340
+ processor, model = processor_t, model_t
341
+ elif model_name == "Video-MTR (Video/Text)":
342
+ processor, model = processor_s, model_s
 
 
 
 
 
 
 
 
 
 
 
 
343
  else:
344
+ yield "Invalid model selected."
345
  return
346
 
347
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_input}]}]
348
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
349
  inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
350
 
 
358
  "top_p": top_p,
359
  "top_k": top_k,
360
  "repetition_penalty": repetition_penalty,
361
+ "do_sample": True if temperature > 0 else False
362
  }
363
 
364
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
367
  buffer = ""
368
  for new_text in streamer:
369
  buffer += new_text
370
+ # Clean up potential model-specific tokens
371
+ buffer = buffer.replace("<|im_end|>", "").replace("</s>", "")
372
  time.sleep(0.01)
373
+ yield buffer
 
 
 
374
 
375
# --- Gradio UI Definition ---
def create_gradio_interface():
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
+         gr.Markdown("# Multimodal VLM v1.0 🚀")
+         gr.Markdown("Explore the capabilities of various Vision Language Models for tasks like OCR, VQA, Object Detection, and Video Tracking.")
+
+         with gr.Tabs():
+             # --- TAB 1: Document and General VLMs ---
+             with gr.TabItem("📄 Document & General VLM"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         gr.Markdown("### 1. Configure Inputs")
+                         model_choice = gr.Dropdown(
+                             choices=["Camel-Doc-OCR-062825 (OCR)", "MinerU2.5-2509 (General)", "Video-MTR (Video/Text)"],
+                             label="Select Model", value="Camel-Doc-OCR-062825 (OCR)"
+                         )
+                         image_input_doc = gr.Image(label="Upload Image", type="pil", sources=['upload'])
+                         prompt_input_doc = gr.Textbox(label="Query Input", placeholder="e.g., 'Transcribe the text in this document.'")
+
+                         with gr.Accordion("Advanced Generation Settings", open=False):
+                             max_new_tokens = gr.Slider(minimum=256, maximum=4096, value=2048, step=128, label="Max New Tokens")
+                             temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7)
+                             top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9)
+                             top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=40)
+                             repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
+
+                         process_btn = gr.Button("🚀 Process Image", variant="primary", elem_classes=["process-button"])
+                         clear_btn = gr.Button("🗑️ Clear", variant="secondary")
+
+                     with gr.Column(scale=2):
+                         gr.Markdown("### 2. View Output")
+                         output_stream = gr.Textbox(label="Model Output", interactive=False, lines=20, show_copy_button=True)

+                 gr.Examples(
+                     examples=[
+                         ["examples/1.png", "Transcribe this receipt."],
+                         ["examples/2.png", "Extract the table from this document as markdown."],
+                         ["examples/3.png", "What information is presented in this infographic?"],
+                     ],
+                     inputs=[image_input_doc, prompt_input_doc]
+                 )
+
+             # --- TAB 2: Moondream3 Lab ---
+             with gr.TabItem("🌝 Moondream3 Lab"):
+                 with gr.Tabs():
+                     with gr.TabItem("🖼️ Image Processing"):
                        with gr.Row():
+                             with gr.Column(scale=1):
+                                 md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
+                                 md3_task_type = gr.Radio(
+                                     choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
+                                     label="Task Type", value="Object Detection"
+                                 )
+                                 md3_prompt_input = gr.Textbox(
+                                     label="Prompt (object to detect/question to ask)",
+                                     placeholder="e.g., 'car', 'person', 'What's in this image?'", value="objects"
+                                 )
+                                 md3_max_objects = gr.Number(
+                                     label="Max Objects (for Object Detection only)",
+                                     value=10, minimum=1, maximum=50, step=1, visible=True
+                                 )
+                                 md3_generate_btn = gr.Button(value="✨ Generate", variant="primary")
+                             with gr.Column(scale=1):
+                                 md3_output_image = gr.Image(type="pil", label="Result", height=400)
+                                 md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
+                                 md3_output_time = gr.Markdown()
+
+                         gr.Examples(
+                             examples=[
+                                 ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG", "Object Detection", "candy", 5],
+                                 ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG", "Point Detection", "candy", 5],
+                                 ["https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg", "Caption", "", 5],
+                                 ["https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg", "Visual Question Answering", "how well does moondream 3 perform in chartvqa?", 5],
+                             ],
+                             inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
+                             label="Click an example to populate inputs"
+                         )

+                     with gr.TabItem("📹 Video Object Tracking"):
+                         with gr.Row():
+                             with gr.Column(scale=1):
+                                 md3_video_input = gr.Video(label="Upload a video file", height=400)
+                                 md3_video_prompt = gr.Textbox(label="Object to track", placeholder="e.g., 'person', 'car', 'ball'", value="person")
+                                 md3_detection_interval = gr.Slider(
+                                     minimum=5, maximum=30, value=15, step=1, label="Detection Interval (frames)",
+                                     info="Run detection every N frames (lower is slower but more accurate)."
+                                 )
+                                 md3_process_video_btn = gr.Button(value="🎥 Process Video", variant="primary")
+                             with gr.Column(scale=1):
+                                 md3_output_video = gr.Video(label="Tracked Video Result", height=400)
+                                 md3_video_summary = gr.Textbox(label="Processing Summary", lines=8, show_copy_button=True)
+                         gr.Examples(
+                             examples=[["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4", "snowboarder", 15]],
+                             inputs=[md3_video_input, md3_video_prompt, md3_detection_interval],
+                             label="Click an example to populate inputs"
+                         )
+
+         # --- Event Handlers ---
+
+         # Document Tab
        process_btn.click(
            fn=process_document_stream,
+             inputs=[model_choice, image_input_doc, prompt_input_doc, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+             outputs=[output_stream]
        )
+         clear_btn.click(lambda: (None, "", ""), outputs=[image_input_doc, prompt_input_doc, output_stream])
+
+         # Moondream3 Tab
+         def update_max_objects_visibility(task):
+             return gr.update(visible=(task == "Object Detection"))
+
+         md3_task_type.change(fn=update_max_objects_visibility, inputs=[md3_task_type], outputs=[md3_max_objects])

+         md3_generate_btn.click(
+             fn=detect_objects_md3,
+             inputs=[md3_image_input, md3_prompt_input, md3_task_type, md3_max_objects],
+             outputs=[md3_output_image, md3_output_textbox, md3_output_time]
        )
+         md3_process_video_btn.click(
+             fn=process_video_md3,
+             inputs=[md3_video_input, md3_video_prompt, md3_detection_interval],
+             outputs=[md3_output_video, md3_video_summary]
        )
+
    return demo

if __name__ == "__main__":
+     # Create some example images if they don't exist
+     if not os.path.exists("examples"):
+         os.makedirs("examples")
+     try:
+         # Dummy image creation for examples to prevent errors if not present
+         Image.new('RGB', (200, 100), color='red').save('examples/1.png')
+         Image.new('RGB', (200, 100), color='green').save('examples/2.png')
+         Image.new('RGB', (200, 100), color='blue').save('examples/3.png')
+     except Exception as e:
+         print(f"Could not create dummy example images: {e}")
+
    demo = create_gradio_interface()
+     demo.queue(max_size=20).launch(share=True, show_error=True)
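
For quick local verification of the Moondream3 detection path this commit adds, the following minimal sketch (not part of the commit) exercises it outside Gradio. It assumes the moondream3-preview remote code exposes detect() with the same settings/normalized-coordinate schema app.py relies on, that it must be loaded with trust_remote_code=True, that the installed supervision version matches the Detections/BoxAnnotator usage above, and that sample.jpg is a placeholder input path.

import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

# Load moondream3-preview as in app.py (the remote code supplies detect/point/caption/query).
model = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream3-preview",
    trust_remote_code=True,  # assumption: needed for the custom detection methods
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)

image = Image.open("sample.jpg")  # placeholder image path
result = model.detect(image, "person", settings={"max_objects": 10})

# detect() returns normalized box coordinates; scale to pixels as create_annotated_image does.
w, h = image.size
boxes = [[o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h]
         for o in result.get("objects", [])]

if boxes:
    detections = sv.Detections(xyxy=np.array(boxes, dtype=np.float32),
                               class_id=np.arange(len(boxes)))
    annotated = sv.BoxAnnotator(thickness=3).annotate(
        scene=np.array(image.convert("RGB")), detections=detections)
    Image.fromarray(annotated).save("annotated.jpg")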