Spaces:
Runtime error
Runtime error
import functools
import json
import os

import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
def load_index(image_vector_file):
    """Build an nmslib HNSW cosine-similarity index over precomputed image vectors.

    Each non-blank line of ``image_vector_file`` must hold an image filename,
    whitespace (tab or space), then the image vector as comma-separated floats.

    Args:
        image_vector_file: Path to the vector file.

    Returns:
        A ``(filenames, index)`` tuple where ``filenames[i]`` is the image whose
        vector was added to ``index`` as data point ``i``.

    Raises:
        ValueError: If a line is malformed or the file contains no vectors.
    """
    filenames, image_vecs = [], []
    with open(image_vector_file, "r", encoding="utf-8") as fvec:
        for lineno, line in enumerate(fvec, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            # The vector is the last whitespace-separated field. Splitting from the
            # right accepts both tab- and space-separated files.
            parts = line.rsplit(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{image_vector_file}:{lineno}: expected '<filename> <v1,v2,...>'"
                )
            filename, vec_text = parts
            filenames.append(filename)
            image_vecs.append(np.array([float(x) for x in vec_text.split(",")]))
    if not image_vecs:
        raise ValueError(f"no image vectors found in {image_vector_file!r}")
    vectors = np.array(image_vecs)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(vectors)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
def load_captions(caption_file):
    """Read a JSON Lines caption file into a ``filename -> captions`` mapping.

    Each non-blank line must be a JSON object with ``"filename"`` and
    ``"captions"`` keys. If a filename appears more than once, the later line wins.

    Args:
        caption_file: Path to the JSON Lines file.

    Returns:
        Dict mapping each filename to its ``"captions"`` value.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"filename"`` or ``"captions"``.
    """
    image2caption = {}
    with open(caption_file, "r", encoding="utf-8") as fcap:
        for line in fcap:
            line = line.strip()
            if not line:
                continue  # JSON Lines files often end with a blank line
            record = json.loads(line)
            image2caption[record["filename"]] = record["captions"]
    return image2caption
@functools.lru_cache(maxsize=1)
def _load_search_resources():
    """Load the CLIP model, text processor and image index once per process.

    Without caching, every search would re-read the model weights and rebuild
    the HNSW index from the vector file.
    """
    model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    filenames, index = load_index("./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv")
    return model, processor, filenames, index


def get_image(text, number):
    """Search the image index with ``text`` and render the top ``number`` hits.

    Args:
        text: Free-text query, embedded with the CLIP text encoder.
        number: How many nearest images to request from the index.

    Each hit is drawn as a row of Streamlit columns: the rank, then the image
    captioned with its filename and similarity score.
    """
    model, processor, filenames, index = _load_search_resources()
    inputs = processor(text=[text], images=None, return_tensors="jax", padding=True)
    # One query text -> one embedding row.
    query_vec = np.asarray(model.get_text_features(**inputs))[0]
    ids, distances = index.knnQuery(query_vec, k=int(number))
    for rank, (image_id, distance) in enumerate(zip(ids, distances), start=1):
        result_filename = filenames[image_id]
        # nmslib's cosinesimil distance is 1 - cosine similarity; display the similarity.
        caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - distance)
        # The third column keeps the original layout. Caption display is disabled.
        # TODO: show ground-truth captions there via
        # load_captions("./images/test-captions.json") if wanted.
        col1, col2, _ = st.columns([2, 10, 10])
        col1.markdown("{:d}.".format(rank))
        # Close the file handle once Streamlit has consumed the image.
        with Image.open(os.path.join("./images", result_filename)) as image:
            col2.image(image, caption=caption)
        st.markdown("---")
# NOTE(review): module-level default, presumably meaning "no suggestion selected".
# It is not referenced by any code visible here; confirm whether another module
# reads it before removing.
suggest_idx: int = -1
def app():
    """Render the text-to-image search page and run a search on request."""
    st.title("Welcome to Space Vector")
    st.text("Search for images that match a given text.")
    text = st.text_input("Enter text: ")
    number = st.number_input("Enter number of images result: ", min_value=1, max_value=10)
    if st.button("Search"):
        if not text.strip():
            # Don't run the model on an empty / whitespace-only query.
            st.warning("Please enter some text to search for.")
            return
        get_image(text, number)