yuhangzang committed
Commit 8327c64 · 1 Parent(s): b4bcbcf
Files changed (1): app.py +65 -48
app.py CHANGED
@@ -25,16 +25,14 @@ def load_model():
     device = get_device()
     dtype = select_dtype(device)
 
+    # Use device_map="auto" for proper GPU allocation with spaces.GPU decorator
     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         MODEL_ID,
         torch_dtype=dtype,
-        device_map="auto" if device == "cuda" else None,
+        device_map="auto",
         trust_remote_code=True,
     )
 
-    if device != "cuda":
-        model.to(device)
-
     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
     return model, processor
 
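Note: dropping the manual CPU fallback is consistent with the new comment, which ties model loading to the spaces.GPU decorator used on HuggingFace ZeroGPU Spaces, where a GPU is attached only for the duration of a decorated call. A minimal sketch of that pairing, assuming the Space wraps its inference entry point like this (the function name is illustrative, not from this commit):

import spaces  # HuggingFace Spaces SDK; available on ZeroGPU hardware

@spaces.GPU  # attaches a GPU for the duration of this call
def run_inference(inputs):
    # MODEL was loaded with device_map="auto", so accelerate has already
    # placed the weights; generation runs on the attached GPU.
    return MODEL.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)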
@@ -48,47 +46,64 @@ def generate_caption(image: Image.Image):
     if image is None:
         return "", 0
 
-    device = MODEL.device
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image"},
-                {"type": "text", "text": DEFAULT_PROMPT},
-            ],
-        }
-    ]
-
-    prompt_text = PROCESSOR.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    inputs = PROCESSOR(
-        text=[prompt_text],
-        images=[image],
-        return_tensors="pt",
-    ).to(device)
-
-    generated_ids = MODEL.generate(
-        **inputs,
-        max_new_tokens=MAX_NEW_TOKENS,
-        do_sample=False,
-    )
-
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = PROCESSOR.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-    caption = output_text[0].strip()
-
-    input_ids = inputs.get("input_ids")
-    input_length = input_ids.shape[-1] if input_ids is not None else 0
-    total_length = generated_ids.shape[-1]
-    num_generated_tokens = max(total_length - input_length, 0)
-
-    return caption, int(num_generated_tokens)
+    try:
+        # Validate image
+        if not isinstance(image, Image.Image):
+            return "Error: Invalid image format", 0
+
+        # Check image size (resize if too large)
+        max_size = 4096
+        if image.width > max_size or image.height > max_size:
+            # Resize in place to prevent OOM; thumbnail() preserves aspect ratio
+            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
+
+        device = MODEL.device
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": DEFAULT_PROMPT},
+                ],
+            }
+        ]
+
+        prompt_text = PROCESSOR.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        inputs = PROCESSOR(
+            text=[prompt_text],
+            images=[image],
+            return_tensors="pt",
+        ).to(device)
+
+        generated_ids = MODEL.generate(
+            **inputs,
+            max_new_tokens=MAX_NEW_TOKENS,
+            do_sample=False,
+        )
+
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = PROCESSOR.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        caption = output_text[0].strip()
+
+        input_ids = inputs.get("input_ids")
+        input_length = input_ids.shape[-1] if input_ids is not None else 0
+        total_length = generated_ids.shape[-1]
+        num_generated_tokens = max(total_length - input_length, 0)
+
+        return caption, int(num_generated_tokens)
+
+    except torch.cuda.OutOfMemoryError:
+        torch.cuda.empty_cache()
+        return "Error: Out of GPU memory. Please try with a smaller image.", 0
+    except Exception as e:
+        return f"Error generating caption: {str(e)}", 0
 
 
 with gr.Blocks(title="CapRL Image Captioning") as demo:
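The new size guard relies on PIL's Image.thumbnail, which resizes in place, never enlarges, and preserves aspect ratio, so oversized uploads are capped without distortion. A standalone sketch (not part of the commit) of the behavior the guard depends on:

from PIL import Image

img = Image.new("RGB", (8000, 2000))
img.thumbnail((4096, 4096), Image.Resampling.LANCZOS)
print(img.size)  # (4096, 1024): longest side capped at 4096, aspect ratio kept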
@@ -109,19 +124,21 @@ with gr.Blocks(title="CapRL Image Captioning") as demo:
             image_input = gr.Image(type="pil", label="Input Image")
             generate_button = gr.Button("Generate Caption")
         with gr.Column():
-            caption_output = gr.Textbox(label="Caption", lines=6, show_copy_button=True)
+            caption_output = gr.Textbox(label="Caption", lines=6)
             token_output = gr.Number(label="Generated Tokens", precision=0)
 
     generate_button.click(
         fn=generate_caption,
         inputs=image_input,
         outputs=[caption_output, token_output],
+        show_progress=True,
     )
 
     image_input.upload(
         fn=generate_caption,
         inputs=image_input,
         outputs=[caption_output, token_output],
+        show_progress=True,
     )
 
     gr.Examples(
@@ -133,7 +150,7 @@ with gr.Blocks(title="CapRL Image Captioning") as demo:
         inputs=image_input,
         outputs=[caption_output, token_output],
         fn=generate_caption,
-        cache_examples=False,
+        cache_examples=True,
         label="📸 Example Images"
     )
 
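Flipping cache_examples to True changes startup behavior: Gradio runs generate_caption once per example image at launch and replays the cached caption and token count on click, trading a slower launch for instant example responses. A minimal illustration of the same wiring (the example path is hypothetical):

gr.Examples(
    examples=["examples/sample.jpg"],  # hypothetical path
    inputs=image_input,
    outputs=[caption_output, token_output],
    fn=generate_caption,
    cache_examples=True,  # precompute outputs at launch, replay on click
)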
@@ -147,7 +164,7 @@ with gr.Blocks(title="CapRL Image Captioning") as demo:
     year={2025}
 }"""
 
-    gr.Code(value=citation_text, language="bibtex", label="BibTeX Citation", show_copy_button=True)
+    gr.Code(value=citation_text, language="bibtex", label="BibTeX Citation")
 
 
 demo.launch()
 