Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -251,44 +251,41 @@ def regenerate(state, image_process_mode):
|
|
| 251 |
|
| 252 |
@spaces.GPU
|
| 253 |
def get_interm_outs(state):
|
| 254 |
-
print("HERERERE")
|
| 255 |
-
print(state)
|
| 256 |
prompt = state.get_prompt()
|
| 257 |
-
print(prompt)
|
| 258 |
images = state.get_images(return_pil=True)
|
| 259 |
#prompt, image_args = process_image(prompt, images)
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
|
| 293 |
return images[0], images[0], images[0]
|
| 294 |
|
|
@@ -450,7 +447,7 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as dem
|
|
| 450 |
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
| 451 |
|
| 452 |
inter_vis_btn.click(
|
| 453 |
-
|
| 454 |
[state],
|
| 455 |
[depth_box, seg_box, gen_box],
|
| 456 |
)
|
|
|
|
| 251 |
|
| 252 |
@spaces.GPU
|
| 253 |
def get_interm_outs(state):
|
|
|
|
|
|
|
| 254 |
prompt = state.get_prompt()
|
|
|
|
| 255 |
images = state.get_images(return_pil=True)
|
| 256 |
#prompt, image_args = process_image(prompt, images)
|
| 257 |
|
| 258 |
+
if images is not None and len(images) > 0:
|
| 259 |
+
if len(images) > 0:
|
| 260 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
| 261 |
+
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
| 262 |
|
| 263 |
+
#images = [load_image_from_base64(image) for image in images]
|
| 264 |
+
image_sizes = [image.size for image in images]
|
| 265 |
+
inp_images = process_images(images, image_processor, model.config)
|
| 266 |
+
|
| 267 |
+
if type(inp_images) is list:
|
| 268 |
+
inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
|
| 269 |
+
else:
|
| 270 |
+
inp_images = inp_images.to(model.device, dtype=torch.float16)
|
| 271 |
+
else:
|
| 272 |
+
inp_images = None
|
| 273 |
+
image_sizes = None
|
| 274 |
+
image_args = {"images": inp_images, "image_sizes": image_sizes}
|
| 275 |
+
else:
|
| 276 |
+
inp_images = None
|
| 277 |
+
image_args = {}
|
| 278 |
+
|
| 279 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
| 280 |
+
|
| 281 |
+
interm_outs = model.get_visual_interpretations(
|
| 282 |
+
input_ids,
|
| 283 |
+
**image_args
|
| 284 |
+
)
|
| 285 |
|
| 286 |
+
depth_outs = get_depth_images(interm_outs, image_sizes[0])
|
| 287 |
+
seg_outs = get_seg_images(interm_outs, images[0])
|
| 288 |
+
gen_outs = get_gen_images(interm_outs)
|
| 289 |
|
| 290 |
return images[0], images[0], images[0]
|
| 291 |
|
|
|
|
# Buttons whose interactivity the chat handlers toggle as a group.
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]

# Wire the intermediate-visualization button: pass the conversation state
# to get_interm_outs and render its three results in the output boxes.
inter_vis_btn.click(
    get_interm_outs,
    [state],
    [depth_box, seg_box, gen_box],
)
|