Spaces:
Runtime error
Runtime error
update
Browse files- models/vsa_model.py +3 -93
models/vsa_model.py
CHANGED
|
@@ -298,12 +298,12 @@ class VisionSearchAssistant:
|
|
| 298 |
self.use_correlate = True
|
| 299 |
|
| 300 |
@spaces.GPU
|
| 301 |
-
def
|
| 302 |
self,
|
| 303 |
image: Union[str, Image.Image, np.ndarray],
|
| 304 |
text: str,
|
| 305 |
-
ground_classes:
|
| 306 |
-
):
|
| 307 |
self.searcher = WebSearcher(
|
| 308 |
model_path = self.search_model
|
| 309 |
)
|
|
@@ -318,96 +318,6 @@ class VisionSearchAssistant:
|
|
| 318 |
load_8bit = self.vlm_load_8bit
|
| 319 |
)
|
| 320 |
|
| 321 |
-
# Create and clear the temporary directory.
|
| 322 |
-
if not os.access('temp', os.F_OK):
|
| 323 |
-
os.makedirs('temp')
|
| 324 |
-
for file in os.listdir('temp'):
|
| 325 |
-
os.remove(os.path.join('temp', file))
|
| 326 |
-
|
| 327 |
-
with open('temp/text.txt', 'w', encoding='utf-8') as wf:
|
| 328 |
-
wf.write(text)
|
| 329 |
-
|
| 330 |
-
# Load Image
|
| 331 |
-
if isinstance(image, str):
|
| 332 |
-
in_image = Image.open(image)
|
| 333 |
-
elif isinstance(image, Image.Image):
|
| 334 |
-
in_image = image
|
| 335 |
-
elif isinstance(image, np.ndarray):
|
| 336 |
-
in_image = Image.fromarray(image.astype(np.uint8))
|
| 337 |
-
else:
|
| 338 |
-
raise Exception('Unsupported input image format.')
|
| 339 |
-
|
| 340 |
-
# Visual Grounding
|
| 341 |
-
bboxes, labels, out_image = self.grounder(in_image, classes = ground_classes)
|
| 342 |
-
|
| 343 |
-
det_images = []
|
| 344 |
-
for bid, bbox in enumerate(bboxes):
|
| 345 |
-
crop_box = (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))
|
| 346 |
-
det_image = in_image.crop(crop_box)
|
| 347 |
-
det_image.save('temp/debug_bbox_image_{}.jpg'.format(bid))
|
| 348 |
-
det_images.append(det_image)
|
| 349 |
-
|
| 350 |
-
if len(det_images) == 0: # No object detected, use the full image.
|
| 351 |
-
det_images.append(in_image)
|
| 352 |
-
labels.append('image')
|
| 353 |
-
|
| 354 |
-
# Visual Captioning
|
| 355 |
-
captions = []
|
| 356 |
-
for det_image, label in zip(det_images, labels):
|
| 357 |
-
inp = get_caption_prompt(label, text)
|
| 358 |
-
caption = self.vlm(det_image, inp)
|
| 359 |
-
captions.append(caption)
|
| 360 |
-
|
| 361 |
-
for cid, caption in enumerate(captions):
|
| 362 |
-
with open('temp/caption_{}.txt'.format(cid), 'w', encoding='utf-8') as wf:
|
| 363 |
-
wf.write(caption)
|
| 364 |
-
|
| 365 |
-
# Visual Correlation
|
| 366 |
-
if len(captions) >= 2 and self.use_correlate:
|
| 367 |
-
queries = []
|
| 368 |
-
for mid, det_image in enumerate(det_images):
|
| 369 |
-
caption = captions[mid]
|
| 370 |
-
other_captions = []
|
| 371 |
-
for cid in range(len(captions)):
|
| 372 |
-
if cid == mid:
|
| 373 |
-
continue
|
| 374 |
-
other_captions.append(captions[cid])
|
| 375 |
-
inp = get_correlate_prompt(caption, other_captions)
|
| 376 |
-
query = self.vlm(det_image, inp)
|
| 377 |
-
queries.append(query)
|
| 378 |
-
else:
|
| 379 |
-
queries = captions
|
| 380 |
-
|
| 381 |
-
for qid, query in enumerate(queries):
|
| 382 |
-
with open('temp/query_{}.txt'.format(qid), 'w', encoding='utf-8') as wf:
|
| 383 |
-
wf.write(query)
|
| 384 |
-
|
| 385 |
-
queries = [text + " " + query for query in queries]
|
| 386 |
-
|
| 387 |
-
# Web Searching
|
| 388 |
-
contexts = self.searcher(queries)
|
| 389 |
-
|
| 390 |
-
# QA
|
| 391 |
-
TOKEN_LIMIT = 3500
|
| 392 |
-
max_length_per_context = TOKEN_LIMIT // len(contexts)
|
| 393 |
-
for cid, context in enumerate(contexts):
|
| 394 |
-
contexts[cid] = (queries[cid] + context)[:max_length_per_context]
|
| 395 |
-
|
| 396 |
-
inp = get_qa_prompt(text, contexts)
|
| 397 |
-
answer = self.vlm(in_image, inp)
|
| 398 |
-
|
| 399 |
-
with open('temp/answer.txt', 'w', encoding='utf-8') as wf:
|
| 400 |
-
wf.write(answer)
|
| 401 |
-
print(answer)
|
| 402 |
-
|
| 403 |
-
return answer
|
| 404 |
-
|
| 405 |
-
def app_run(
|
| 406 |
-
self,
|
| 407 |
-
image: Union[str, Image.Image, np.ndarray],
|
| 408 |
-
text: str,
|
| 409 |
-
ground_classes: List[str] = COCO_CLASSES
|
| 410 |
-
):
|
| 411 |
# Create and clear the temporary directory.
|
| 412 |
if not os.access('temp', os.F_OK):
|
| 413 |
os.makedirs('temp')
|
|
|
|
| 298 |
self.use_correlate = True
|
| 299 |
|
| 300 |
@spaces.GPU
|
| 301 |
+
def app_run(
|
| 302 |
self,
|
| 303 |
image: Union[str, Image.Image, np.ndarray],
|
| 304 |
text: str,
|
| 305 |
+
ground_classes: List[str] = COCO_CLASSES
|
| 306 |
+
):
|
| 307 |
self.searcher = WebSearcher(
|
| 308 |
model_path = self.search_model
|
| 309 |
)
|
|
|
|
| 318 |
load_8bit = self.vlm_load_8bit
|
| 319 |
)
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
# Create and clear the temporary directory.
|
| 322 |
if not os.access('temp', os.F_OK):
|
| 323 |
os.makedirs('temp')
|