prithivMLmods committed
Commit cec9320 · verified · 1 Parent(s): a7fe8f1

upload app

Files changed (1)
1. app.py +380 -0
app.py ADDED
@@ -0,0 +1,380 @@
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread
import uuid
import tempfile

import gradio as gr
import numpy as np
import torch
from PIL import Image
import supervision as sv


from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    Glm4vForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("--- System Information ---")
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
print("Using device:", device)
print("--------------------------")

# --- Model Loading ---

# Load Camel-Doc-OCR-062825
print("Loading Camel-Doc-OCR-062825...")
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
print("Camel-Doc-OCR-062825 loaded.")

# Load GLM-4.1V-9B-Thinking
print("Loading GLM-4.1V-9B-Thinking...")
MODEL_ID_T = "zai-org/GLM-4.1V-9B-Thinking"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Glm4vForConditionalGeneration.from_pretrained(
    MODEL_ID_T,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
print("GLM-4.1V-9B-Thinking loaded.")

# Load moondream3
print("Loading moondream3-preview...")
MODEL_ID_MD3 = "moondream/moondream3-preview"
model_md3 = AutoModelForCausalLM.from_pretrained(
    MODEL_ID_MD3,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)
model_md3.compile()
print("moondream3-preview loaded and compiled.")

# --- Moondream3 Utility Functions ---
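# NOTE: the two helpers below only assume the shape of the result dicts they
# are handed, inferred from how the fields are read in this file; treat this
# schema as an assumption about the moondream3-preview skills rather than a
# documented contract:
#   detect()-style results: {"objects": [{"x_min", "y_min", "x_max", "y_max"}, ...]}
#   point()-style results:  {"points": [{"x", "y"}, ...]}
# Coordinates are normalized to [0, 1] and scaled to pixel values before
# drawing with supervision.
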
def create_annotated_image(image, detection_result, object_name="Object"):
    if not isinstance(detection_result, dict) or "objects" not in detection_result:
        return image

    original_width, original_height = image.size
    annotated_image = np.array(image.convert("RGB"))

    bboxes = []
    labels = []

    for i, obj in enumerate(detection_result["objects"]):
        x_min = int(obj["x_min"] * original_width)
        y_min = int(obj["y_min"] * original_height)
        x_max = int(obj["x_max"] * original_width)
        y_max = int(obj["y_max"] * original_height)

        x_min = max(0, min(x_min, original_width))
        y_min = max(0, min(y_min, original_height))
        x_max = max(0, min(x_max, original_width))
        y_max = max(0, min(y_max, original_height))

        if x_max > x_min and y_max > y_min:
            bboxes.append([x_min, y_min, x_max, y_max])
            labels.append(f"{object_name} {i+1}")

    if not bboxes:
        return image

    detections = sv.Detections(
        xyxy=np.array(bboxes, dtype=np.float32),
        class_id=np.arange(len(bboxes))
    )

    bounding_box_annotator = sv.BoxAnnotator(
        thickness=3,
        color_lookup=sv.ColorLookup.INDEX
    )
    label_annotator = sv.LabelAnnotator(
        text_thickness=2,
        text_scale=0.6,
        color_lookup=sv.ColorLookup.INDEX
    )

    annotated_image = bounding_box_annotator.annotate(
        scene=annotated_image, detections=detections
    )
    annotated_image = label_annotator.annotate(
        scene=annotated_image, detections=detections, labels=labels
    )

    return Image.fromarray(annotated_image)


def create_point_annotated_image(image, point_result):
    if not isinstance(point_result, dict) or "points" not in point_result:
        return image

    original_width, original_height = image.size
    annotated_image = np.array(image.convert("RGB"))

    points = []
    for point in point_result["points"]:
        x = int(point["x"] * original_width)
        y = int(point["y"] * original_height)
        points.append([x, y])

    if points:
        points_array = np.array(points).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(
            scene=annotated_image, key_points=key_points
        )

    return Image.fromarray(annotated_image)

@spaces.GPU()
def detect_objects_md3(image, prompt, task_type, max_objects):
    STANDARD_SIZE = (1024, 1024)
    if image is None:
        raise gr.Error("Please upload an image.")
    image.thumbnail(STANDARD_SIZE)

    t0 = time.perf_counter()

    # Route the request to the corresponding moondream3 skill.
    if task_type == "Object Detection":
        settings = {"max_objects": max_objects} if max_objects > 0 else {}
        result = model_md3.detect(image, prompt, settings=settings)
        annotated_image = create_annotated_image(image, result, prompt)
    elif task_type == "Point Detection":
        result = model_md3.point(image, prompt)
        annotated_image = create_point_annotated_image(image, result)
    elif task_type == "Caption":
        result = model_md3.caption(image, length="normal")
        annotated_image = image
    else:
        result = model_md3.query(image=image, question=prompt, reasoning=True)
        annotated_image = image

    elapsed_ms = (time.perf_counter() - t0) * 1_000

    # Summarize the structured result as plain text for the response textbox.
    if isinstance(result, dict):
        if "objects" in result:
            output_text = f"Found {len(result['objects'])} objects:\n"
            for i, obj in enumerate(result['objects'], 1):
                output_text += f"\n{i}. Bounding box: ({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
        elif "points" in result:
            output_text = f"Found {len(result['points'])} points:\n"
            for i, point in enumerate(result['points'], 1):
                output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
        elif "caption" in result:
            output_text = result['caption']
        elif "answer" in result:
            output_text = f"Reasoning: {result.get('reasoning', 'N/A')}\n\nAnswer: {result['answer']}"
        else:
            output_text = json.dumps(result, indent=2)
    else:
        output_text = str(result)

    timing_text = f"Inference time: {elapsed_ms:.0f} ms"

    return annotated_image, output_text, timing_text

# --- Core Application Logic (for other models) ---
@spaces.GPU
def process_document_stream(
    model_name: str,
    image: Image.Image,
    prompt_input: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float
):
    """
    Main generator function for models other than Moondream3.
    """
    if image is None:
        yield "Please upload an image."
        return
    if not prompt_input or not prompt_input.strip():
        yield "Please enter a prompt."
        return

    # Select processor and model based on dropdown choice
    if model_name == "Camel-Doc-OCR-062825 (OCR)":
        processor, model = processor_m, model_m
    elif model_name == "GLM-4.1V-9B (Thinking)":
        processor, model = processor_t, model_t
    else:
        yield "Invalid model selected."
        return

    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_input}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "do_sample": True if temperature > 0 else False
    }

    # Run generation in a background thread; the streamer yields partial text
    # as it is produced so the UI can update incrementally.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Clean up potential model-specific tokens
        buffer = buffer.replace("<|im_end|>", "").replace("</s>", "")
        time.sleep(0.01)
        yield buffer

# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.Markdown("# Multimodal VLM v1.0 ⚡")
        gr.Markdown("Explore the capabilities of various Vision Language Models for tasks like OCR, VQA, and Object Detection.")

        with gr.Tabs():
            # --- TAB 1: Document and General VLMs ---
            with gr.TabItem("📄 Document & General VLM"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # gr.Markdown("### 1. Configure Inputs")
                        model_choice = gr.Dropdown(
                            choices=["Camel-Doc-OCR-062825 (OCR)", "GLM-4.1V-9B (Thinking)"],
                            label="Select Model", value="Camel-Doc-OCR-062825 (OCR)"
                        )
                        image_input_doc = gr.Image(label="Upload Image", type="pil", sources=['upload'], height=280)
                        prompt_input_doc = gr.Textbox(label="Query Input", placeholder="e.g., 'Transcribe the text in this document.'")

                        with gr.Accordion("Advanced Generation Settings", open=False):
                            max_new_tokens = gr.Slider(minimum=256, maximum=4096, value=2048, step=128, label="Max New Tokens")
                            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7)
                            top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9)
                            top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=40)
                            repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)

                        process_btn = gr.Button("🚀 Process", variant="primary", elem_classes=["process-button"])
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")

                    with gr.Column(scale=2):
                        # gr.Markdown("### 2. View Output")
                        with gr.Tab("Output Stream"):
                            output_stream = gr.Textbox(label="Model Output", interactive=False, lines=24, show_copy_button=True)

                gr.Examples(
                    examples=[
                        ["examples/1.jpg", "Transcribe this receipt."],
                        ["examples/2.jpg", "Extract the content."],
                        ["examples/3.jpg", "OCR the image."],
                    ],
                    inputs=[image_input_doc, prompt_input_doc]
                )

            # --- TAB 2: Moondream3 Lab ---
            with gr.TabItem("🌝 Moondream3"):
                with gr.Row():
                    with gr.Column(scale=1):
                        md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
                        md3_task_type = gr.Radio(
                            choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
                            label="Task Type", value="Object Detection"
                        )
                        md3_prompt_input = gr.Textbox(
                            label="Prompt (object to detect/question to ask)",
                            placeholder="e.g., 'car', 'person', 'What's in this image?'"
                        )
                        md3_max_objects = gr.Number(
                            label="Max Objects (for Object Detection only)",
                            value=10, minimum=1, maximum=50, step=1, visible=True
                        )
                        md3_generate_btn = gr.Button(value="🚀 Process", variant="primary", elem_classes=["process-button"])
                    with gr.Column(scale=1):
                        md3_output_image = gr.Image(type="pil", label="Result", height=400)
                        md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
                        md3_output_time = gr.Markdown()

                gr.Examples(
                    examples=[
                        ["md3/1.jpg", "Object Detection", "boats", 7],
                        ["md3/2.jpg", "Point Detection", "candy", 7],
                        ["md3/3.png", "Caption", "", 5],
                        ["md3/4.jpeg", "Visual Question Answering", "Analyze the GDP trend over the years.", 5],
                    ],
                    inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
                    label="Click an example to populate inputs"
                )

        # --- Event Handlers ---

        # Document Tab
        process_btn.click(
            fn=process_document_stream,
            inputs=[model_choice, image_input_doc, prompt_input_doc, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[output_stream]
        )
        clear_btn.click(lambda: (None, "", ""), outputs=[image_input_doc, prompt_input_doc, output_stream])

        # Moondream3 Tab
        def update_max_objects_visibility(task):
            return gr.update(visible=(task == "Object Detection"))

        md3_task_type.change(fn=update_max_objects_visibility, inputs=[md3_task_type], outputs=[md3_max_objects])

        md3_generate_btn.click(
            fn=detect_objects_md3,
            inputs=[md3_image_input, md3_prompt_input, md3_task_type, md3_max_objects],
            outputs=[md3_output_image, md3_output_textbox, md3_output_time]
        )

    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue(max_size=50).launch(share=True, show_error=True)
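
For reference, a minimal dependency sketch inferred from the imports in app.py. The package names follow directly from the code; the list itself is not part of the commit, and the inclusion of accelerate (to support the device_map argument) is an assumption:

# requirements.txt (sketch, unpinned)
gradio
spaces
torch
transformers
accelerate
supervision
pillow
numpy

Locally, python app.py builds the interface and launches it through demo.queue(...).launch(...); when deployed as a Hugging Face ZeroGPU Space, the @spaces.GPU decorators request a GPU for the decorated inference functions.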