import jax
import flax
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st

from tempfile import NamedTemporaryFile
from torchvision.transforms import Compose, Resize, ToPILImage
from transformers import CLIPProcessor, FlaxCLIPModel
from PIL import Image

BASELINE_MODEL = "openai/clip-vit-base-patch32"
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd-v2"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"
# @st.cache(allow_output_mutation=True)
# def load_index():
#     filenames, image_vecs = [], []
#     fvec = open(IMAGE_VECTOR_FILE, "r")
#     for line in fvec:
#         cols = line.strip().split('\t')
#         filename = cols[0]
#         image_vec = np.array([float(x) for x in cols[1].split(',')])
#         filenames.append(filename)
#         image_vecs.append(image_vec)
#     V = np.array(image_vecs)
#     index = nmslib.init(method='hnsw', space='cosinesimil')
#     index.addDataPointBatch(V)
#     index.createIndex({'post': 2}, print_progress=True)
#     return filenames, index

def load_model():
    # model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
    # processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
    model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
    processor = CLIPProcessor.from_pretrained("flax-community/clip-rsicd-v2")
    return model, processor
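
# Note: as with the commented-out load_index above, load_model could be
# wrapped in @st.cache(allow_output_mutation=True) so that the CLIP weights
# are not reloaded on every Streamlit rerun.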

def split_image(X):
    # Tile the image into non-overlapping 224x224 patches, cropping off
    # any remainder on the right and bottom edges.
    num_rows = X.shape[0] // 224
    num_cols = X.shape[1] // 224
    Xc = X[0 : num_rows * 224, 0 : num_cols * 224, :]
    patches = []
    for j in range(num_rows):
        for i in range(num_cols):
            patches.append(Xc[j * 224 : (j + 1) * 224,
                              i * 224 : (i + 1) * 224,
                              :])
    return num_rows, num_cols, patches
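
# Worked example (hypothetical input): a 500 x 700 RGB image gives
# num_rows = 500 // 224 = 2 and num_cols = 700 // 224 = 3, so the image is
# cropped to 448 x 672 and split into six 224 x 224 patches, row by row.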

def get_patch_probabilities(patches, searched_feature,
                            image_preprocessor,
                            model, processor):
    images = [image_preprocessor(patch) for patch in patches]
    text = "An aerial image of {:s}".format(searched_feature)
    inputs = processor(images=images,
                       text=text,
                       return_tensors="jax",
                       padding=True)
    outputs = model(**inputs)
    probs = jax.nn.softmax(outputs.logits_per_text, axis=-1)
    probs_np = np.asarray(probs)[0]
    return probs_np
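
# Since there is a single text query, logits_per_text has shape
# (1, num_patches), and the softmax spreads probability mass across the
# patches: the returned values sum to 1 over the tiles rather than being
# independent per-tile confidences.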

def get_image_ranks(probs):
    temp = np.argsort(-probs)
    ranks = np.empty_like(temp)
    ranks[temp] = np.arange(len(probs))
    return ranks
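
# Example (hypothetical values): probs = [0.2, 0.5, 0.3] gives
# temp = argsort(-probs) = [1, 2, 0] and ranks = [2, 0, 1], i.e. the
# highest-probability patch gets rank 0 (displayed 1-based below).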

def app():
    model, processor = load_model()
    st.title("Find Features in Images")
    st.markdown("""
The CLIP model from OpenAI is trained in a self-supervised manner using
contrastive learning to project images and caption text onto a common
embedding space. We have fine-tuned the model using the RSICD dataset
(10k images and ~50k captions from the remote sensing domain).

This demo shows the ability of the model to find specific features
(specified as text queries) in an image. As an example, say you wish to
find the parts of the following image that contain a `beach`, `houses`,
or `ships`. We partition the image into (224, 224) tiles and report how
likely each tile is to contain each text feature.
""")
    st.image("demo-images/st_tropez_1.png")
    st.image("demo-images/st_tropez_2.png")
    st.markdown("""
For this image and the queries listed above, our model reports that the
two left tiles are most likely to contain a `beach`, the two top right
tiles are most likely to contain `houses`, and the two bottom right tiles
are most likely to contain `boats`.

You can try it yourself with your own photographs.
[Unsplash](https://unsplash.com/s/photos/aerial-view) has some good
aerial photographs. You will need to download the image from Unsplash to
your computer and upload it to the demo app.
""")
    with st.form(key="form_3"):
        buf = st.file_uploader("Upload Image for Analysis")
        searched_feature = st.text_input(label="Feature to find")
        submit_button = st.form_submit_button("Find")
    if submit_button:
        if buf is None:
            st.error("Please upload an image to analyze")
            st.stop()
        ftmp = NamedTemporaryFile()
        ftmp.write(buf.getvalue())
        image = plt.imread(ftmp.name)
        if len(image.shape) != 3 or image.shape[2] != 3:
            st.error("Image should be an RGB image")
            st.stop()  # halt the rerun rather than process an invalid image
        if image.shape[0] < 224 or image.shape[1] < 224:
            st.error("Image should be at least (224 x 224)")
            st.stop()
        st.image(image, caption="Input Image")
        st.markdown("---")
        image_preprocessor = Compose([
            ToPILImage(),
            Resize(224)
        ])
        num_rows, num_cols, patches = split_image(image)
        patch_probs = get_patch_probabilities(
            patches,
            searched_feature,
            image_preprocessor,
            model,
            processor)
        patch_ranks = get_image_ranks(patch_probs)
        for i in range(num_rows):
            row_patches = patches[i * num_cols : (i + 1) * num_cols]
            row_probs = patch_probs[i * num_cols : (i + 1) * num_cols]
            row_ranks = patch_ranks[i * num_cols : (i + 1) * num_cols]
            captions = ["p({:s})={:.3f}, rank={:d}".format(searched_feature, p, r + 1)
                        for p, r in zip(row_probs, row_ranks)]
            st.image(row_patches, caption=captions)
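
# Entry point: Streamlit executes this script top to bottom; assuming the
# file is launched directly with `streamlit run`, invoke the app here.
app()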