Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
rerank model
Browse files- RAG/colpali.py +32 -33
- pages/Multimodal_Conversational_Search.py +1 -1
RAG/colpali.py
CHANGED
|
@@ -17,13 +17,12 @@ import matplotlib.pyplot as plt
|
|
| 17 |
import requests
|
| 18 |
import boto3
|
| 19 |
import streamlit as st
|
| 20 |
-
from IPython.display import display, Markdown
|
| 21 |
import base64
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
import torch
|
| 28 |
# from colpali_engine.models import ColPali, ColPaliProcessor
|
| 29 |
# from colpali_engine.utils.torch_utils import get_torch_device
|
|
@@ -286,37 +285,37 @@ def img_highlight(img,batch_queries,query_tokens):
|
|
| 286 |
print(f"Number of image patches: {n_patches}")
|
| 287 |
|
| 288 |
# # Generate the similarity maps
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
|
| 296 |
-
# #
|
| 297 |
-
|
| 298 |
|
| 299 |
-
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
|
| 318 |
-
|
| 319 |
-
return
|
| 320 |
|
| 321 |
|
| 322 |
|
|
|
|
| 17 |
import requests
|
| 18 |
import boto3
|
| 19 |
import streamlit as st
|
|
|
|
| 20 |
import base64
|
| 21 |
+
from colpali_engine.interpretability import (
|
| 22 |
+
get_similarity_maps_from_embeddings,
|
| 23 |
+
plot_all_similarity_maps,
|
| 24 |
+
plot_similarity_map,
|
| 25 |
+
)
|
| 26 |
import torch
|
| 27 |
# from colpali_engine.models import ColPali, ColPaliProcessor
|
| 28 |
# from colpali_engine.utils.torch_utils import get_torch_device
|
|
|
|
| 285 |
print(f"Number of image patches: {n_patches}")
|
| 286 |
|
| 287 |
# # Generate the similarity maps
|
| 288 |
+
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
| 289 |
+
image_embeddings=image_embeddings,
|
| 290 |
+
query_embeddings=query_embeddings,
|
| 291 |
+
n_patches=n_patches,
|
| 292 |
+
image_mask = image_mask
|
| 293 |
+
)
|
| 294 |
|
| 295 |
+
# # Get the similarity map for our (only) input image
|
| 296 |
+
similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
|
| 297 |
|
| 298 |
+
query_tokens_from_model = query_tokens[0]['tokens']
|
| 299 |
|
| 300 |
+
plots = plot_all_similarity_maps(
|
| 301 |
+
image=image,
|
| 302 |
+
query_tokens=query_tokens_from_model,
|
| 303 |
+
similarity_maps=similarity_maps,
|
| 304 |
+
figsize=(8, 8),
|
| 305 |
+
show_colorbar=False,
|
| 306 |
+
add_title=True,
|
| 307 |
+
)
|
| 308 |
+
map_images = []
|
| 309 |
+
for idx, (fig, ax) in enumerate(plots):
|
| 310 |
+
if(idx<3):
|
| 311 |
+
continue
|
| 312 |
+
savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
|
| 313 |
+
fig.savefig(savepath, bbox_inches="tight")
|
| 314 |
+
map_images.append({'file':savepath})
|
| 315 |
+
print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
|
| 316 |
|
| 317 |
+
plt.close("all")
|
| 318 |
+
return map_images
|
| 319 |
|
| 320 |
|
| 321 |
|
pages/Multimodal_Conversational_Search.py
CHANGED
|
@@ -13,7 +13,7 @@ import botocore.session
|
|
| 13 |
import json
|
| 14 |
import random
|
| 15 |
import string
|
| 16 |
-
import rag_DocumentLoader
|
| 17 |
import rag_DocumentSearcher
|
| 18 |
import pandas as pd
|
| 19 |
from PIL import Image
|
|
|
|
| 13 |
import json
|
| 14 |
import random
|
| 15 |
import string
|
| 16 |
+
#import rag_DocumentLoader
|
| 17 |
import rag_DocumentSearcher
|
| 18 |
import pandas as pd
|
| 19 |
from PIL import Image
|