prithivMLmods committed on
Commit 4a363da · verified · 1 Parent(s): aa59ead

update app

Files changed (1):
    app.py  (+346 -145)
app.py CHANGED

@@ -5,7 +5,9 @@ import json
 import time
 import asyncio
 from threading import Thread
-from typing import Iterable
+from pathlib import Path
+from io import BytesIO
+from typing import Optional, Tuple, Dict, Any, Iterable
 
 import gradio as gr
 import spaces
@@ -13,40 +15,44 @@ import torch
 import numpy as np
 from PIL import Image
 import cv2
+import requests
+import fitz
 
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     Qwen3VLForConditionalGeneration,
-    AutoTokenizer,
     AutoProcessor,
     TextIteratorStreamer,
 )
 from transformers.image_utils import load_image
+
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
-colors.steel_blue = colors.Color(
-    name="steel_blue",
-    c50="#EBF3F8",
-    c100="#D3E5F0",
-    c200="#A8CCE1",
-    c300="#7DB3D2",
-    c400="#529AC3",
-    c500="#4682B4",  # SteelBlue base color
-    c600="#3E72A0",
-    c700="#36638C",
-    c800="#2E5378",
-    c900="#264364",
-    c950="#1E3450",
-)
+# --- Theme and CSS Definition ---
 
+# Define the new OrangeRed color palette
+colors.orange_red = colors.Color(
+    name="orange_red",
+    c50="#FFF0E5",
+    c100="#FFE0CC",
+    c200="#FFC299",
+    c300="#FFA366",
+    c400="#FF8533",
+    c500="#FF4500",  # OrangeRed base color
+    c600="#E63E00",
+    c700="#CC3700",
+    c800="#B33000",
+    c900="#992900",
+    c950="#802200",
+)
 
-class SteelBlueTheme(Soft):
+class OrangeRedTheme(Soft):
     def __init__(
         self,
         *,
        primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.steel_blue,
+        secondary_hue: colors.Color | str = colors.orange_red,  # Use the new color
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -73,8 +79,8 @@ class SteelBlueTheme(Soft):
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
+            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
+            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
@@ -92,51 +98,104 @@ class SteelBlueTheme(Soft):
            block_label_background_fill="*primary_200",
        )
 
-steel_blue_theme = SteelBlueTheme()
+# Instantiate the new theme
+orange_red_theme = OrangeRedTheme()
+
+css = """
+#main-title h1 {
+    font-size: 2.3em !important;
+}
+#output-title h2 {
+    font-size: 2.1em !important;
+}
+"""
 
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
 
-# Load Qwen2.5-VL-7B-Instruct
-MODEL_ID_M = "Qwen/Qwen2.5-VL-7B-Instruct"
-processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
-model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_M,
+# --- Model Loading ---
+
+# Load Qwen3-VL-4B-Instruct
+MODEL_ID_Q4B = "Qwen/Qwen3-VL-4B-Instruct"
+processor_q4b = AutoProcessor.from_pretrained(MODEL_ID_Q4B, trust_remote_code=True)
+model_q4b = Qwen3VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_Q4B,
     trust_remote_code=True,
-    torch_dtype=torch.float16
+    torch_dtype=torch.bfloat16
 ).to(device).eval()
 
-# Load Qwen2.5-VL-3B-Instruct
-MODEL_ID_X = "Qwen/Qwen2.5-VL-3B-Instruct"
-processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
-model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_X,
+# Load Qwen3-VL-8B-Instruct
+MODEL_ID_Q8B = "Qwen/Qwen3-VL-8B-Instruct"
+processor_q8b = AutoProcessor.from_pretrained(MODEL_ID_Q8B, trust_remote_code=True)
+model_q8b = Qwen3VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_Q8B,
     trust_remote_code=True,
-    torch_dtype=torch.float16
+    torch_dtype=torch.bfloat16
 ).to(device).eval()
 
 # Load Qwen3-VL-2B-Instruct
-MODEL_ID_Q = "Qwen/Qwen3-VL-2B-Instruct"
-processor_q = AutoProcessor.from_pretrained(MODEL_ID_Q, trust_remote_code=True)
-model_q = Qwen3VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_Q,
+MODEL_ID_Q2B = "Qwen/Qwen3-VL-2B-Instruct"
+processor_q2b = AutoProcessor.from_pretrained(MODEL_ID_Q2B, trust_remote_code=True)
+model_q2b = Qwen3VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_Q2B,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16
+).to(device).eval()
+
+# Load Qwen2.5-VL-7B-Instruct
+MODEL_ID_M7B = "Qwen/Qwen2.5-VL-7B-Instruct"
+processor_m7b = AutoProcessor.from_pretrained(MODEL_ID_M7B, trust_remote_code=True)
+model_m7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M7B,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load Qwen2.5-VL-3B-Instruct
+MODEL_ID_X3B = "Qwen/Qwen2.5-VL-3B-Instruct"
+processor_x3b = AutoProcessor.from_pretrained(MODEL_ID_X3B, trust_remote_code=True)
+model_x3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_X3B,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
 
+
+# --- Helper Functions ---
+
+def select_model(model_name: str):
+    if model_name == "Qwen3-VL-4B-Instruct":
+        return processor_q4b, model_q4b
+    elif model_name == "Qwen3-VL-8B-Instruct":
+        return processor_q8b, model_q8b
+    elif model_name == "Qwen3-VL-2B-Instruct":
+        return processor_q2b, model_q2b
+    elif model_name == "Qwen2.5-VL-7B-Instruct":
+        return processor_m7b, model_m7b
+    elif model_name == "Qwen2.5-VL-3B-Instruct":
+        return processor_x3b, model_x3b
+    else:
+        raise ValueError("Invalid model selected.")
+
+def extract_gif_frames(gif_path: str):
+    if not gif_path:
+        return []
+    with Image.open(gif_path) as gif:
+        total_frames = gif.n_frames
+        frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
+        frames = []
+        for i in frame_indices:
+            gif.seek(i)
+            frames.append(gif.convert("RGB").copy())
+    return frames
+
 def downsample_video(video_path):
-    """
-    Downsamples the video to evenly spaced frames.
-    Each frame is returned as a PIL image along with its timestamp.
-    """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
-    # Use a maximum of 10 frames to avoid excessive memory usage
     frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
     for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
@@ -144,39 +203,75 @@ def downsample_video(video_path):
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
+            frames.append(pil_image)
    vidcap.release()
    return frames
 
-@spaces.GPU
-def generate_image(model_name: str, text: str, image: Image.Image,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
-    """
-    Generates responses using the selected model for image input.
-    """
-    if model_name == "Qwen2.5-VL-7B-Instruct":
-        processor, model = processor_m, model_m
-    elif model_name == "Qwen2.5-VL-3B-Instruct":
-        processor, model = processor_x, model_x
-    elif model_name == "Qwen3-VL-2B-Instruct":
-        processor, model = processor_q, model_q
-    else:
-        yield "Invalid model selected.", "Invalid model selected."
-        return
+def convert_pdf_to_images(file_path: str, dpi: int = 200):
+    if not file_path:
+        return []
+    images = []
+    pdf_document = fitz.open(file_path)
+    zoom = dpi / 72.0
+    mat = fitz.Matrix(zoom, zoom)
+    for page_num in range(len(pdf_document)):
+        page = pdf_document.load_page(page_num)
+        pix = page.get_pixmap(matrix=mat)
+        img_data = pix.tobytes("png")
+        images.append(Image.open(BytesIO(img_data)))
+    pdf_document.close()
+    return images
+
+def get_initial_pdf_state() -> Dict[str, Any]:
+    return {"pages": [], "total_pages": 0, "current_page_index": 0}
+
+def load_and_preview_pdf(file_path: Optional[str]) -> Tuple[Optional[Image.Image], Dict[str, Any], str]:
+    state = get_initial_pdf_state()
+    if not file_path:
+        return None, state, '<div style="text-align:center;">No file loaded</div>'
+    try:
+        pages = convert_pdf_to_images(file_path)
+        if not pages:
+            return None, state, '<div style="text-align:center;">Could not load file</div>'
+        state["pages"] = pages
+        state["total_pages"] = len(pages)
+        page_info_html = f'<div style="text-align:center;">Page 1 / {state["total_pages"]}</div>'
+        return pages[0], state, page_info_html
+    except Exception as e:
+        return None, state, f'<div style="text-align:center;">Failed to load preview: {e}</div>'
+
+def navigate_pdf_page(direction: str, state: Dict[str, Any]):
+    if not state or not state["pages"]:
+        return None, state, '<div style="text-align:center;">No file loaded</div>'
+    current_index = state["current_page_index"]
+    total_pages = state["total_pages"]
+    if direction == "prev":
+        new_index = max(0, current_index - 1)
+    elif direction == "next":
+        new_index = min(total_pages - 1, current_index + 1)
+    else:
+        new_index = current_index
+    state["current_page_index"] = new_index
+    image_preview = state["pages"][new_index]
+    page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
+    return image_preview, state, page_info_html
 
+# --- Generation Functions ---
+
+@spaces.GPU
+def generate_image(model_name: str, text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
-
+    try:
+        processor, model = select_model(model_name)
+    except ValueError as e:
+        yield str(e), str(e)
+        return
+
    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
-        text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
+    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -188,86 +283,143 @@ def generate_image(model_name: str, text: str, image: Image.Image,
        yield buffer, buffer
 
 @spaces.GPU
-def generate_video(model_name: str, text: str, video_path: str,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
-    """
-    Generates responses using the selected model for video input.
-    """
-    if model_name == "Qwen2.5-VL-7B-Instruct":
-        processor, model = processor_m, model_m
-    elif model_name == "Qwen2.5-VL-3B-Instruct":
-        processor, model = processor_x, model_x
-    elif model_name == "Qwen3-VL-2B-Instruct":
-        processor, model = processor_q, model_q
-    else:
-        yield "Invalid model selected.", "Invalid model selected."
-        return
-
+def generate_video(model_name: str, text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
    if video_path is None:
        yield "Please upload a video.", "Please upload a video."
        return
+    try:
+        processor, model = select_model(model_name)
+    except ValueError as e:
+        yield str(e), str(e)
+        return
 
-    frames_with_ts = downsample_video(video_path)
-    if not frames_with_ts:
+    frames = downsample_video(video_path)
+    if not frames:
        yield "Could not process video.", "Could not process video."
        return
 
    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    images_for_processor = []
-    for frame, timestamp in frames_with_ts:
-        messages[0]["content"].append({"type": "image"})
-        images_for_processor.append(frame)
-
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
-        text=[prompt_full], images=images_for_processor, return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs, "streamer": streamer, "max_new_tokens": max_new_tokens,
-        "do_sample": True, "temperature": temperature, "top_p": top_p,
-        "top_k": top_k, "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        #buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
-        yield buffer, buffer
-
-
-# Define examples for image and video inference
-image_examples = [
-    ["Explain the content in detail.", "images/D.jpg"],
-    ["Explain the content (ocr).", "images/O.jpg"],
-    ["What is the core meaning of the poem?", "images/S.jpg"],
-    ["Provide a detailed caption for the image.", "images/A.jpg"],
-    #["Explain the pie-chart in detail.", "images/2.jpg"],
-    #["Jsonify Data.", "images/1.jpg"],
-]
-
-video_examples = [
-    ["Explain the ad in detail", "videos/1.mp4"],
-    ["Identify the main actions in the video", "videos/2.mp4"],
-]
-
-css = """
-#main-title h1 {
-    font-size: 2.3em !important;
-}
-#output-title h2 {
-    font-size: 2.1em !important;
-}
-"""
-
-# Create the Gradio Interface
-with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
-    gr.Markdown("# **Qwen3-VL-Outpost**", elem_id="main-title")
+    for frame in frames:
+        messages[0]["content"].insert(0, {"type": "image"})
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer, buffer
+
+@spaces.GPU
+def generate_pdf(model_name: str, text: str, state: Dict[str, Any], max_new_tokens: int = 2048, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
+    if not state or not state["pages"]:
+        yield "Please upload a PDF file first.", "Please upload a PDF file first."
+        return
+    try:
+        processor, model = select_model(model_name)
+    except ValueError as e:
+        yield str(e), str(e)
+        return
+
+    page_images = state["pages"]
+    full_response = ""
+    for i, image in enumerate(page_images):
+        page_header = f"--- Page {i+1}/{len(page_images)} ---\n"
+        yield full_response + page_header, full_response + page_header
+        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        page_buffer = ""
+        for new_text in streamer:
+            page_buffer += new_text
+            yield full_response + page_header + page_buffer, full_response + page_header + page_buffer
+            time.sleep(0.01)
+        full_response += page_header + page_buffer + "\n\n"
+
+@spaces.GPU
+def generate_caption(model_name: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
+    if image is None:
+        yield "Please upload an image to caption.", "Please upload an image to caption."
+        return
+    try:
+        processor, model = select_model(model_name)
+    except ValueError as e:
+        yield str(e), str(e)
+        return
+
+    system_prompt = (
+        "You are an AI assistant. For the given image, write a precise caption and provide a structured set of "
+        "attributes describing visual elements like objects, people, actions, colors, and environment."
+    )
+    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": system_prompt}]}]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        time.sleep(0.01)
+        yield buffer, buffer
+
+@spaces.GPU
+def generate_gif(model_name: str, text: str, gif_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
+    if gif_path is None:
+        yield "Please upload a GIF.", "Please upload a GIF."
+        return
+    try:
+        processor, model = select_model(model_name)
+    except ValueError as e:
+        yield str(e), str(e)
+        return
+
+    frames = extract_gif_frames(gif_path)
+    if not frames:
+        yield "Could not process GIF.", "Could not process GIF."
+        return
+    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
+    for frame in frames:
+        messages[0]["content"].insert(0, {"type": "image"})
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer, buffer
+
+# --- Examples and Gradio UI ---
+
+image_examples = [["Perform OCR on the image...", "examples/images/1.jpg"],
+                  ["Caption the image. Describe the safety measures shown in the image. Conclude whether the situation is (safe or unsafe)...", "examples/images/2.jpg"],
+                  ["Solve the problem...", "examples/images/3.png"]]
+video_examples = [["Explain the Ad video in detail.", "examples/videos/1.mp4"],
+                  ["Explain the video in detail.", "examples/videos/2.mp4"]]
+pdf_examples = [["Extract the content precisely.", "examples/pdfs/doc1.pdf"],
+                ["Analyze and provide a short report.", "examples/pdfs/doc2.pdf"]]
+gif_examples = [["Describe this GIF.", "examples/gifs/1.gif"],
+                ["Describe this GIF.", "examples/gifs/2.gif"]]
+caption_examples = [["examples/captions/1.JPG"],
+                    ["examples/captions/2.jpeg"], ["examples/captions/3.jpeg"]]
+
+with gr.Blocks(theme=orange_red_theme, css=css) as demo:
+    pdf_state = gr.State(value=get_initial_pdf_state())
+    gr.Markdown("# **Qwen-VL: Multimodal Outpost**", elem_id="main-title")
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Tabs():
@@ -276,42 +428,91 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
                    image_submit = gr.Button("Submit", variant="primary")
                    gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
 
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    video_upload = gr.Video(label="Upload Video", height=290)
+                    video_upload = gr.Video(label="Upload Video(≤30s)", height=290)
                    video_submit = gr.Button("Submit", variant="primary")
                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
 
+                with gr.TabItem("PDF Inference"):
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            pdf_query = gr.Textbox(label="Query Input", placeholder="e.g., 'Summarize this document'")
+                            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
+                            pdf_submit = gr.Button("Submit", variant="primary")
+                        with gr.Column(scale=1):
+                            pdf_preview_img = gr.Image(label="PDF Preview", height=290)
+                            with gr.Row():
+                                prev_page_btn = gr.Button("◀ Previous")
+                                page_info = gr.HTML('<div style="text-align:center;">No file loaded</div>')
+                                next_page_btn = gr.Button("Next ▶")
+                    gr.Examples(examples=pdf_examples, inputs=[pdf_query, pdf_upload])
+
+                with gr.TabItem("Gif Inference"):
+                    gif_query = gr.Textbox(label="Query Input", placeholder="e.g., 'What is happening in this gif?'")
+                    gif_upload = gr.Image(type="filepath", label="Upload GIF", height=290)
+                    gif_submit = gr.Button("Submit", variant="primary")
+                    gr.Examples(examples=gif_examples, inputs=[gif_query, gif_upload])
+
+                with gr.TabItem("Caption"):
+                    caption_image_upload = gr.Image(type="pil", label="Image to Caption", height=290)
+                    caption_submit = gr.Button("Generate Caption", variant="primary")
+                    gr.Examples(examples=caption_examples, inputs=[caption_image_upload])
+
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
 
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
-            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
+            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=14, show_copy_button=True)
            with gr.Accordion("(Result.md)", open=False):
-                markdown_output = gr.Markdown()
+                markdown_output = gr.Markdown(label="(Result.Md)", latex_delimiters=[
+                    {"left": "$$", "right": "$$", "display": True},
+                    {"left": "$", "right": "$", "display": False}
+                ])
 
            model_choice = gr.Radio(
-                choices=["Qwen3-VL-2B-Instruct", "Qwen2.5-VL-3B-Instruct", "Qwen2.5-VL-7B-Instruct"],
+                choices=[
+                    "Qwen3-VL-4B-Instruct",
+                    "Qwen3-VL-8B-Instruct",
+                    "Qwen3-VL-2B-Instruct",
+                    "Qwen2.5-VL-7B-Instruct",
+                    "Qwen2.5-VL-3B-Instruct"
+                ],
                label="Select Model",
-                value="Qwen3-VL-2B-Instruct"
+                value="Qwen3-VL-4B-Instruct"
            )
 
-    image_submit.click(
-        fn=generate_image,
-        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[output, markdown_output]
-    )
-    video_submit.click(
-        fn=generate_video,
-        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[output, markdown_output]
-    )
+    # --- Event Handlers ---
+
+    image_submit.click(fn=generate_image,
+                       inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+                       outputs=[output, markdown_output])
+
+    video_submit.click(fn=generate_video,
+                       inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+                       outputs=[output, markdown_output])
+
+    pdf_submit.click(fn=generate_pdf,
+                     inputs=[model_choice, pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+                     outputs=[output, markdown_output])
+
+    gif_submit.click(fn=generate_gif,
+                     inputs=[model_choice, gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+                     outputs=[output, markdown_output])
+
+    caption_submit.click(fn=generate_caption,
+                         inputs=[model_choice, caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+                         outputs=[output, markdown_output])
+
+    pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
+    prev_page_btn.click(fn=lambda s: navigate_pdf_page("prev", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
+    next_page_btn.click(fn=lambda s: navigate_pdf_page("next", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
 
 if __name__ == "__main__":
    demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)
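
All five generate_* handlers added by this commit share one streaming pattern: model.generate runs on a worker thread while a TextIteratorStreamer hands partial text back to the Gradio generator, which is what lets the UI show tokens as they arrive. A minimal standalone sketch of that pattern, assuming the Qwen/Qwen3-VL-2B-Instruct checkpoint from the diff, a CUDA device, and a hypothetical local test.jpg:

# Sketch of the streamer-plus-thread pattern used by every generate_* handler.
# The model id matches the diff; "test.jpg" is an assumed local file.
from threading import Thread

import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Qwen3VLForConditionalGeneration,
    TextIteratorStreamer,
)

MODEL_ID = "Qwen/Qwen3-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda").eval()

messages = [{"role": "user", "content": [{"type": "image"},
                                         {"type": "text", "text": "Describe the image."}]}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[Image.open("test.jpg")], return_tensors="pt").to("cuda")

# generate() fills the streamer from the worker thread; iterating the streamer
# on the main thread yields decoded text fragments until generation finishes.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256}).start()
for fragment in streamer:
    print(fragment, end="", flush=True)

Passing the processor itself to TextIteratorStreamer, rather than a bare tokenizer, mirrors the app's own code; the Qwen-VL processors forward decode calls to their underlying tokenizer.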
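
A packaging note for the new PDF tab: the "import fitz" added at the top of the file is PyMuPDF, so the Space's requirements need the pymupdf package (requests is also newly imported). Below is a standalone sketch of the rasterization step inside convert_pdf_to_images, assuming a hypothetical local doc.pdf; the dpi/72 zoom factor matches the diff's 200-dpi default.

# Sketch of the fitz (PyMuPDF) rendering used by convert_pdf_to_images.
# "doc.pdf" is an assumed local file.
from io import BytesIO

import fitz  # PyMuPDF
from PIL import Image

pdf = fitz.open("doc.pdf")
zoom = 200 / 72.0                      # dpi / 72, as in the diff
mat = fitz.Matrix(zoom, zoom)
pages = []
for page in pdf:
    pix = page.get_pixmap(matrix=mat)  # rasterize one page at the chosen scale
    pages.append(Image.open(BytesIO(pix.tobytes("png"))))
pdf.close()
print(f"rendered {len(pages)} page(s)")

generate_pdf then walks these page images one at a time, streaming each page's answer and prefixing it with a "--- Page i/N ---" header before moving on to the next page.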