Spaces:
Build error
Build error
| import matplotlib.pyplot as plt | |
| import nmslib | |
| import numpy as np | |
| import os | |
| import requests | |
| import streamlit as st | |
| from PIL import Image | |
| from transformers import CLIPProcessor, FlaxCLIPModel | |
| import utils | |
| BASELINE_MODEL = "openai/clip-vit-base-patch32" | |
| MODEL_PATH = "flax-community/clip-rsicd-v2" | |
| IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv" | |
| IMAGES_DIR = "./images" | |
| CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json") | |
| def load_example_images(): | |
| example_images = {} | |
| image_names = os.listdir(IMAGES_DIR) | |
| for image_name in image_names: | |
| if image_name.find("_") < 0: | |
| continue | |
| image_class = image_name.split("_")[0] | |
| if image_class in example_images.keys(): | |
| example_images[image_class].append(image_name) | |
| else: | |
| example_images[image_class] = [image_name] | |
| example_image_list = sorted([v[np.random.randint(0, len(v))] | |
| for k, v in example_images.items()][0:10]) | |
| return example_image_list | |
| def get_image_thumbnail(image_filename): | |
| image = Image.open(os.path.join(IMAGES_DIR, image_filename)) | |
| image = image.resize((100, 100)) | |
| return image | |
| def download_and_prepare_image(image_url): | |
| try: | |
| image_raw = requests.get(image_url, stream=True,).raw | |
| image = Image.open(image_raw).convert("RGB") | |
| width, height = image.size | |
| resize_mult = width / 224 if width < height else height / 224 | |
| image = image.resize((int(width // resize_mult), | |
| int(height // resize_mult))) | |
| width, height = image.size | |
| left = int((width - 224) // 2) | |
| top = int((height - 224) // 2) | |
| right = int((width + 224) // 2) | |
| bottom = int((height + 224) // 2) | |
| image = image.crop((left, top, right, bottom)) | |
| return image | |
| except Exception as e: | |
| return None | |
| def app(): | |
| filenames, index = utils.load_index(IMAGE_VECTOR_FILE) | |
| model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL) | |
| image2caption = utils.load_captions(CAPTIONS_FILE) | |
| example_image_list = load_example_images() | |
| st.title("Retrieve Images given Images") | |
| st.markdown(""" | |
| This demo shows the image to image retrieval capabilities of this model, i.e., | |
| given an image file name as a query, we use our fine-tuned CLIP model | |
| to project the query image to the image/caption embedding space and search | |
| for nearby images (by cosine similarity) in this space. | |
| Our fine-tuned CLIP model was previously used to generate image vectors for | |
| our demo, and NMSLib was used for fast vector access. | |
| Here are some randomly generated image files from our corpus, that you can | |
| find similar images for by selecting the button below it. Alternatively you | |
| can upload your own image from the Internet. | |
| """) | |
| suggest_idx = -1 | |
| col0, col1, col2, col3, col4 = st.columns(5) | |
| col0.image(get_image_thumbnail(example_image_list[0])) | |
| col1.image(get_image_thumbnail(example_image_list[1])) | |
| col2.image(get_image_thumbnail(example_image_list[2])) | |
| col3.image(get_image_thumbnail(example_image_list[3])) | |
| col4.image(get_image_thumbnail(example_image_list[4])) | |
| col0t, col1t, col2t, col3t, col4t = st.columns(5) | |
| with col0t: | |
| if st.button("Image-1"): | |
| suggest_idx = 0 | |
| with col1t: | |
| if st.button("Image-2"): | |
| suggest_idx = 1 | |
| with col2t: | |
| if st.button("Image-3"): | |
| suggest_idx = 2 | |
| with col3t: | |
| if st.button("Image-4"): | |
| suggest_idx = 3 | |
| with col4t: | |
| if st.button("Image-5"): | |
| suggest_idx = 4 | |
| col5, col6, col7, col8, col9 = st.columns(5) | |
| col5.image(get_image_thumbnail(example_image_list[5])) | |
| col6.image(get_image_thumbnail(example_image_list[6])) | |
| col7.image(get_image_thumbnail(example_image_list[7])) | |
| col8.image(get_image_thumbnail(example_image_list[8])) | |
| col9.image(get_image_thumbnail(example_image_list[9])) | |
| col5t, col6t, col7t, col8t, col9t = st.columns(5) | |
| with col5t: | |
| if st.button("Image-6"): | |
| suggest_idx = 5 | |
| with col6t: | |
| if st.button("Image-7"): | |
| suggest_idx = 6 | |
| with col7t: | |
| if st.button("Image-8"): | |
| suggest_idx = 7 | |
| with col8t: | |
| if st.button("Image-9"): | |
| suggest_idx = 8 | |
| with col9t: | |
| if st.button("Image-10"): | |
| suggest_idx = 9 | |
| image_url = st.text_input( | |
| "OR provide an image URL", | |
| value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg") | |
| submit_button = st.button("Find Similar") | |
| if submit_button or suggest_idx > -1: | |
| image_name = None | |
| if suggest_idx > -1: | |
| image_name = example_image_list[suggest_idx] | |
| image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name))) | |
| else: | |
| image = download_and_prepare_image(image_url) | |
| st.image(image, caption="Input Image") | |
| st.markdown("---") | |
| if image is None: | |
| st.error("Image could not be downloaded, please try another one!") | |
| else: | |
| inputs = processor(images=image, return_tensors="jax", padding=True) | |
| query_vec = model.get_image_features(**inputs) | |
| query_vec = np.asarray(query_vec) | |
| ids, distances = index.knnQuery(query_vec, k=11) | |
| result_filenames = [filenames[id] for id in ids] | |
| rank = 0 | |
| for result_filename, score in zip(result_filenames, distances): | |
| if image_name is not None and result_filename == image_name: | |
| continue | |
| caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score) | |
| col1, col2, col3 = st.columns([2, 10, 10]) | |
| col1.markdown("{:d}.".format(rank + 1)) | |
| col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)), | |
| caption=caption) | |
| caption_text = [] | |
| for caption in image2caption[result_filename]: | |
| caption_text.append("* {:s}\n".format(caption)) | |
| col3.markdown("".join(caption_text)) | |
| rank += 1 | |
| st.markdown("---") | |
| suggest_idx = -1 | |