Spaces: Build error

Sujit Pal committed · Commit f58917e
1 Parent(s): 5de821f

fix: changing output format to include caption

Files changed:
- dashboard_image2image.py (+14 -13)
- dashboard_text2image.py (+14 -16)
- utils.py (+15 -0)
dashboard_image2image.py CHANGED

@@ -12,11 +12,9 @@ import utils
 
 BASELINE_MODEL = "openai/clip-vit-base-patch32"
 MODEL_PATH = "flax-community/clip-rsicd-v2"
-
 IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
-
 IMAGES_DIR = "./images"
-
+CAPTIONS_FILE = os.path.join(IMAGES_DIR, "dataset_rsicd.json")
 
 @st.cache(allow_output_mutation=True)
 def load_example_images():

@@ -62,6 +60,7 @@ def download_and_prepare_image(image_url):
 def app():
     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+    image2caption = utils.load_captions(CAPTIONS_FILE)
 
     example_image_list = load_example_images()
 

@@ -150,17 +149,19 @@
         query_vec = np.asarray(query_vec)
         ids, distances = index.knnQuery(query_vec, k=11)
         result_filenames = [filenames[id] for id in ids]
-
+        rank = 0
         for result_filename, score in zip(result_filenames, distances):
             if image_name is not None and result_filename == image_name:
                 continue
-
-
-
-
-
-
-
-
-
+            caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+            col1, col2, col3 = st.beta_columns([2, 10, 10])
+            col1.markdown("{:d}.".format(rank + 1))
+            col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
+                       caption=caption)
+            caption_text = []
+            for caption in image2caption[result_filename]:
+                caption_text.append("* {:s}\n".format(caption))
+            col3.markdown("".join(caption_text))
+            rank += 1
+            st.markdown("---")
         suggest_idx = -1
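The new display loop asks the index for k=11 neighbors because the query image itself can come back as its own nearest neighbor; the `continue` skips it, and the manual `rank` counter keeps the displayed ranks contiguous after the skip. The `1.0 - score` conversion assumes the nmslib index behind `utils.load_index` was built in the "cosinesimil" space, where `knnQuery` returns distance = 1 - cosine similarity. A minimal sketch of that relationship (illustrative names and parameters, not repo code):

import nmslib
import numpy as np

# Stand-in for the CLIP image embeddings loaded from the TSV file.
vectors = np.random.random((100, 512)).astype(np.float32)

# In the "cosinesimil" space, distances are 1 - cosine similarity
# (assumed to match how utils.load_index builds its index).
index = nmslib.init(method="hnsw", space="cosinesimil")
index.addDataPointBatch(vectors)
index.createIndex({"M": 30, "efConstruction": 100}, print_progress=False)

ids, distances = index.knnQuery(vectors[0], k=11)
similarities = 1.0 - distances  # what the dashboard reports as "score"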
dashboard_text2image.py CHANGED

@@ -4,25 +4,21 @@ import numpy as np
 import os
 import streamlit as st
 
+from PIL import Image
 from transformers import CLIPProcessor, FlaxCLIPModel
 
 import utils
 
 BASELINE_MODEL = "openai/clip-vit-base-patch32"
-# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
 MODEL_PATH = "flax-community/clip-rsicd-v2"
-
-# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
-# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
 IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
-
-# IMAGES_DIR = "/home/shared/data/rsicd_images"
 IMAGES_DIR = "./images"
-
+CAPTIONS_FILE = os.path.join(IMAGES_DIR, "dataset_rsicd.json")
 
 def app():
     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+    image2caption = utils.load_captions(CAPTIONS_FILE)
 
     st.title("Retrieve Images given Text")
     st.markdown("""

@@ -78,13 +74,15 @@ def app():
         query_vec = np.asarray(query_vec)
         ids, distances = index.knnQuery(query_vec, k=10)
         result_filenames = [filenames[id] for id in ids]
-
-
-
-
-
-
-
-
-
+        for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
+            caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+            col1, col2, col3 = st.beta_columns([2, 10, 10])
+            col1.markdown("{:d}.".format(rank + 1))
+            col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
+                       caption=caption)
+            caption_text = []
+            for caption in image2caption[result_filename]:
+                caption_text.append("* {:s}\n".format(caption))
+            col3.markdown("".join(caption_text))
+            st.markdown("---")
         suggest_idx = -1
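Unlike the image-to-image dashboard, nothing is skipped here, so the loop takes its rank directly from enumerate instead of a manual counter. Note that st.beta_columns is the pre-1.0 Streamlit API; newer releases renamed it to st.columns, so on a current install the same result row would look like the sketch below (filename, score, and caption invented for illustration):

import streamlit as st

# Three columns with relative widths 2:10:10 -- rank, image, reference captions.
col1, col2, col3 = st.columns([2, 10, 10])
col1.markdown("1.")
col2.image("./images/airport_1.jpg", caption="airport_1.jpg (score: 0.734)")
col3.markdown("* an illustrative reference caption for the retrieved image\n")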
utils.py CHANGED

@@ -1,3 +1,4 @@
+import json
 import matplotlib.pyplot as plt
 import nmslib
 import numpy as np

@@ -31,3 +32,17 @@ def load_model(model_path, baseline_model):
     # processor = CLIPProcessor.from_pretrained(baseline_model)
     processor = CLIPProcessor.from_pretrained(model_path)
     return model, processor
+
+
+@st.cache(allow_output_mutation=True)
+def load_captions(caption_file):
+    image2caption = {}
+    with open(caption_file, "r") as fcap:
+        data = json.loads(fcap.read())
+    for image in data["images"]:
+        filename = image["filename"]
+        captions = []
+        for sentence in image["sentences"]:
+            captions.append(sentence["raw"])
+        image2caption[filename] = captions
+    return image2caption
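load_captions reads the Karpathy-style caption format that dataset_rsicd.json uses: a top-level "images" list whose entries each carry a "filename" and a list of "sentences" holding the raw caption text. A minimal illustration of the shape it consumes and the mapping it returns (caption text invented for the example):

# Expected input shape:
data = {
    "images": [
        {
            "filename": "airport_1.jpg",
            "sentences": [
                {"raw": "many planes are parked at the airport ."},
                {"raw": "several planes sit near a terminal building ."},
            ],
        },
    ]
}

# load_captions flattens this into:
# {"airport_1.jpg": ["many planes are parked at the airport .",
#                    "several planes sit near a terminal building ."]}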