prithivMLmods committed · verified
Commit a7fe8f1 · Parent: d417121

Delete app.py

Files changed (1)
  1. app.py +0 -380
app.py DELETED
@@ -1,380 +0,0 @@
- import spaces
- import json
- import math
- import os
- import traceback
- from io import BytesIO
- from typing import Any, Dict, List, Optional, Tuple
- import re
- import time
- from threading import Thread
- from io import BytesIO
- import uuid
- import tempfile
-
- import gradio as gr
- import numpy as np
- import torch
- from PIL import Image
- import supervision as sv
-
-
- from transformers import (
-     Qwen2_5_VLForConditionalGeneration,
-     Glm4vForConditionalGeneration,
-     Qwen2VLForConditionalGeneration,
-     AutoModelForCausalLM,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
-
- # --- Constants and Model Setup ---
- MAX_INPUT_TOKEN_LENGTH = 4096
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- print("--- System Information ---")
- print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
- print("torch.__version__ =", torch.__version__)
- print("torch.version.cuda =", torch.version.cuda)
- print("CUDA available:", torch.cuda.is_available())
- print("CUDA device count:", torch.cuda.device_count())
- if torch.cuda.is_available():
-     print("Current device:", torch.cuda.current_device())
-     print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
- print("Using device:", device)
- print("--------------------------")
-
-
- # --- Model Loading ---
-
- # Load Camel-Doc-OCR-062825
- print("Loading Camel-Doc-OCR-062825...")
- MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
- print("Camel-Doc-OCR-062825 loaded.")
-
- # Load GLM-4.1V-9B-Thinking
- print("Loading GLM-4.1V-9B-Thinking...")
- MODEL_ID_T = "zai-org/GLM-4.1V-9B-Thinking"
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
- model_t = Glm4vForConditionalGeneration.from_pretrained(
-     MODEL_ID_T,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
- print("GLM-4.1V-9B-Thinking loaded.")
-
- # Load moondream3
- print("Loading moondream3-preview...")
- MODEL_ID_MD3 = "moondream/moondream3-preview"
- model_md3 = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID_MD3,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16,
-     device_map={"": "cuda"},
- )
- model_md3.compile()
- print("moondream3-preview loaded and compiled.")
-
-
- # --- Moondream3 Utility Functions ---
-
- def create_annotated_image(image, detection_result, object_name="Object"):
-     if not isinstance(detection_result, dict) or "objects" not in detection_result:
-         return image
-
-     original_width, original_height = image.size
-     annotated_image = np.array(image.convert("RGB"))
-
-     bboxes = []
-     labels = []
-
-     for i, obj in enumerate(detection_result["objects"]):
-         x_min = int(obj["x_min"] * original_width)
-         y_min = int(obj["y_min"] * original_height)
-         x_max = int(obj["x_max"] * original_width)
-         y_max = int(obj["y_max"] * original_height)
-
-         x_min = max(0, min(x_min, original_width))
-         y_min = max(0, min(y_min, original_height))
-         x_max = max(0, min(x_max, original_width))
-         y_max = max(0, min(y_max, original_height))
-
-         if x_max > x_min and y_max > y_min:
-             bboxes.append([x_min, y_min, x_max, y_max])
-             labels.append(f"{object_name} {i+1}")
-
-     if not bboxes:
-         return image
-
-     detections = sv.Detections(
-         xyxy=np.array(bboxes, dtype=np.float32),
-         class_id=np.arange(len(bboxes))
-     )
-
-     bounding_box_annotator = sv.BoxAnnotator(
-         thickness=3,
-         color_lookup=sv.ColorLookup.INDEX
-     )
-     label_annotator = sv.LabelAnnotator(
-         text_thickness=2,
-         text_scale=0.6,
-         color_lookup=sv.ColorLookup.INDEX
-     )
-
-     annotated_image = bounding_box_annotator.annotate(
-         scene=annotated_image, detections=detections
-     )
-     annotated_image = label_annotator.annotate(
-         scene=annotated_image, detections=detections, labels=labels
-     )
-
-     return Image.fromarray(annotated_image)
-
- def create_point_annotated_image(image, point_result):
-     if not isinstance(point_result, dict) or "points" not in point_result:
-         return image
-
-     original_width, original_height = image.size
-     annotated_image = np.array(image.convert("RGB"))
-
-     points = []
-     for point in point_result["points"]:
-         x = int(point["x"] * original_width)
-         y = int(point["y"] * original_height)
-         points.append([x, y])
-
-     if points:
-         points_array = np.array(points).reshape(1, -1, 2)
-         key_points = sv.KeyPoints(xy=points_array)
-         vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
-         annotated_image = vertex_annotator.annotate(
-             scene=annotated_image, key_points=key_points
-         )
-
-     return Image.fromarray(annotated_image)
-
- @spaces.GPU()
- def detect_objects_md3(image, prompt, task_type, max_objects):
-     STANDARD_SIZE = (1024, 1024)
-     if image is None:
-         raise gr.Error("Please upload an image.")
-     image.thumbnail(STANDARD_SIZE)
-
-     t0 = time.perf_counter()
-
-     if task_type == "Object Detection":
-         settings = {"max_objects": max_objects} if max_objects > 0 else {}
-         result = model_md3.detect(image, prompt, settings=settings)
-         annotated_image = create_annotated_image(image, result, prompt)
-     elif task_type == "Point Detection":
-         result = model_md3.point(image, prompt)
-         annotated_image = create_point_annotated_image(image, result)
-     elif task_type == "Caption":
-         result = model_md3.caption(image, length="normal")
-         annotated_image = image
-     else:
-         result = model_md3.query(image=image, question=prompt, reasoning=True)
-         annotated_image = image
-
-     elapsed_ms = (time.perf_counter() - t0) * 1_000
-
-     if isinstance(result, dict):
-         if "objects" in result:
-             output_text = f"Found {len(result['objects'])} objects:\n"
-             for i, obj in enumerate(result['objects'], 1):
-                 output_text += f"\n{i}. Bounding box: ({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
-         elif "points" in result:
-             output_text = f"Found {len(result['points'])} points:\n"
-             for i, point in enumerate(result['points'], 1):
-                 output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
-         elif "caption" in result:
-             output_text = result['caption']
-         elif "answer" in result:
-             output_text = f"Reasoning: {result.get('reasoning', 'N/A')}\n\nAnswer: {result['answer']}"
-         else:
-             output_text = json.dumps(result, indent=2)
-     else:
-         output_text = str(result)
-
-     timing_text = f"Inference time: {elapsed_ms:.0f} ms"
-
-     return annotated_image, output_text, timing_text
-
-
- # --- Core Application Logic (for other models) ---
- @spaces.GPU
- def process_document_stream(
-     model_name: str,
-     image: Image.Image,
-     prompt_input: str,
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
-     top_k: int,
-     repetition_penalty: float
- ):
-     """
-     Main generator function for models other than Moondream3.
-     """
-     if image is None:
-         yield "Please upload an image."
-         return
-     if not prompt_input or not prompt_input.strip():
-         yield "Please enter a prompt."
-         return
-
-     # Select processor and model based on dropdown choice
-     if model_name == "Camel-Doc-OCR-062825 (OCR)":
-         processor, model = processor_m, model_m
-     elif model_name == "GLM-4.1V-9B (Thinking)":
-         processor, model = processor_t, model_t
-     else:
-         yield "Invalid model selected."
-         return
-
-     messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_input}]}]
-     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
-         "do_sample": True if temperature > 0 else False
-     }
-
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         # Clean up potential model-specific tokens
-         buffer = buffer.replace("<|im_end|>", "").replace("</s>", "")
-         time.sleep(0.01)
-         yield buffer
-
- # --- Gradio UI Definition ---
- def create_gradio_interface():
-     """Builds and returns the Gradio web interface."""
-     css = """
-     .main-container { max-width: 1400px; margin: 0 auto; }
-     .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
-     .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
-     .processr-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
-     .processr-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
-     """
-     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
-         gr.Markdown("# Multimodal VLM v1.0 ⚡")
-         gr.Markdown("Explore the capabilities of various Vision Language Models for tasks like OCR, VQA, and Object Detection.")
-
-         with gr.Tabs():
-             # --- TAB 1: Document and General VLMs ---
-             with gr.TabItem("📄 Document & General VLM"):
-                 with gr.Row():
-                     with gr.Column(scale=1):
-                         #gr.Markdown("### 1. Configure Inputs")
-                         model_choice = gr.Dropdown(
-                             choices=["Camel-Doc-OCR-062825 (OCR)", "GLM-4.1V-9B (Thinking)"],
-                             label="Select Model", value="Camel-Doc-OCR-062825 (OCR)"
-                         )
-                         image_input_doc = gr.Image(label="Upload Image", type="pil", sources=['upload'], height=280)
-                         prompt_input_doc = gr.Textbox(label="Query Input", placeholder="e.g., 'Transcribe the text in this document.'")
-
-                         with gr.Accordion("Advanced Generation Settings", open=False):
-                             max_new_tokens = gr.Slider(minimum=256, maximum=4096, value=2048, step=128, label="Max New Tokens")
-                             temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7)
-                             top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9)
-                             top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=40)
-                             repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
-
-                         process_btn = gr.Button("🚀 Process", variant="primary", elem_classes=["process-button"])
-                         clear_btn = gr.Button("🗑️ Clear", variant="secondary")
-
-                     with gr.Column(scale=2):
-                         #gr.Markdown("### 2. View Output")
-                         with gr.Tab("Output Stream"):
-                             output_stream = gr.Textbox(label="Model Output", interactive=False, lines=24, show_copy_button=True)
-
-                 gr.Examples(
-                     examples=[
-                         ["examples/1.jpg", "Transcribe this receipt."],
-                         ["examples/2.jpg", "Extract the content."],
-                         ["examples/3.jpg", "OCR the image."],
-                     ],
-                     inputs=[image_input_doc, prompt_input_doc]
-                 )
-
-             # --- TAB 2: Moondream3 Lab ---
-             with gr.TabItem("🌝 Moondream3"):
-                 with gr.Row():
-                     with gr.Column(scale=1):
-                         md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
-                         md3_task_type = gr.Radio(
-                             choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
-                             label="Task Type", value="Object Detection"
-                         )
-                         md3_prompt_input = gr.Textbox(
-                             label="Prompt (object to detect/question to ask)",
-                             placeholder="e.g., 'car', 'person', 'What's in this image?'"
-                         )
-                         md3_max_objects = gr.Number(
-                             label="Max Objects (for Object Detection only)",
-                             value=10, minimum=1, maximum=50, step=1, visible=True
-                         )
-                         md3_generate_btn = gr.Button(value="🚀 Process", variant="primary", elem_classes=["processr-button"])
-                     with gr.Column(scale=1):
-                         md3_output_image = gr.Image(type="pil", label="Result", height=400)
-                         md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
-                         md3_output_time = gr.Markdown()
-
-                 gr.Examples(
-                     examples=[
-                         ["md3/1.jpg", "Object Detection", "boats", 7],
-                         ["md3/2.jpg", "Point Detection", "candy", 7],
-                         ["md3/3.png", "Caption", "", 5],
-                         ["md3/4.jpeg", "Visual Question Answering", "Analyze the GDP trend over the years.", 5],
-                     ],
-                     inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
-                     label="Click an example to populate inputs"
-                 )
-
-         # --- Event Handlers ---
-
-         # Document Tab
-         process_btn.click(
-             fn=process_document_stream,
-             inputs=[model_choice, image_input_doc, prompt_input_doc, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-             outputs=[output_stream]
-         )
-         clear_btn.click(lambda: (None, "", ""), outputs=[image_input_doc, prompt_input_doc, output_stream])
-
-         # Moondream3 Tab
-         def update_max_objects_visibility(task):
-             return gr.update(visible=(task == "Object Detection"))
-
-         md3_task_type.change(fn=update_max_objects_visibility, inputs=[md3_task_type], outputs=[md3_max_objects])
-
-         md3_generate_btn.click(
-             fn=detect_objects_md3,
-             inputs=[md3_image_input, md3_prompt_input, md3_task_type, md3_max_objects],
-             outputs=[md3_output_image, md3_output_textbox, md3_output_time]
-         )
-
-     return demo
-
- if __name__ == "__main__":
-     demo = create_gradio_interface()
-     demo.queue(max_size=50).launch(share=True, show_error=True)