import os

import jax
import jax.numpy as jnp
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from model import FlaxHybridCLIP
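# Streamlit demo for the flax-community "clip-reply" model: type a short text
# and retrieve matching reaction GIFs by embedding text and images into a
# shared CLIP space.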
st.sidebar.title("CLIP React Demo")
st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
sc = st.sidebar.columns(2)
sc[0].image("./huggingface_explode3.png", width=150)
sc[1].write(" ")
sc[1].write(" ")
sc[1].markdown("## Researching fun")
with st.sidebar.expander("Motivation", expanded=True):
    st.markdown(
        """
Reaction GIFs have become an integral part of communication.
They convey complex, layered emotions in a short, compact format.
If a picture is worth a thousand words, then a GIF is worth more.
Many people would agree that it is not always easy to find the perfect reaction GIF.
This is just a first step toward the more ambitious goal of GIF/image generation.
"""
    )
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
col_count = 4
file_names = os.listdir("./jpg")
file_names.sort()

show_val = st.sidebar.button("Show all validation set images")
if show_val:
    cols = st.sidebar.columns(col_count)
    for i, im in enumerate(file_names):
        j = i % col_count
        cols[j].image("./jpg/" + im)

st.write("# Search Reaction GIFs with CLIP")
st.write(" ")
st.write(" ")
def load_model():
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # The hybrid model pairs CLIP's image encoder with a Twitter RoBERTa text
    # encoder, so the default CLIP tokenizer is swapped for the matching one.
    processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    return model, processor


def load_image_index():
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index


image_index = load_image_index()
model, processor = load_model()
# TODO: currently unused - embeds a single image and adds it to the search index.
def add_image_emb(image):
    image = Image.open(image).convert("RGB")
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    # The Flax vision model expects channels-last (NHWC) pixel values.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    # NOTE: nmslib's addDataPoint expects (id, data); an id would still need
    # to be assigned here before this works.
    image_index.addDataPoint(features)
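# Rank the user's uploaded images against the query text: the model's
# logits_per_image are softmaxed into relative scores and the images are
# returned best-first.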
def query_with_images(query_images, query_text):
    images = []
    for im in query_images:
        img = Image.open(im)
        if im.name.endswith(".gif"):
            # Seek before converting: convert("RGB") flattens an animated GIF
            # to its current frame, so seeking afterwards has no effect.
            img.seek(0)  # use only the first frame, matching training
        images.append(img.convert("RGB"))
    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image.reshape(-1)
    probs = jax.nn.softmax(logits_per_image)
    results = sorted(zip(images, probs), key=lambda x: x[1], reverse=True)
    return zip(*results)
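# Query UI: pick an example query or type your own text; optionally upload
# images to be ranked against it.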
q_cols = st.columns([5, 2, 5])
examples = [
    "OMG that is disgusting",
    "I'm so scared right now",
    "I got the job 🎉",
    "Congratulations to all the flax-community week teams",
    "You're awesome",
    "I love you ❤️",
]
example_input = q_cols[0].radio(
    "Example Queries:",
    examples,
    index=4,
    help="These are examples I wrote off the top of my head. They don't occur in the dataset.",
)
q_cols[2].markdown(
    """
Searches among the validation set images if no images are uploaded
(there may be near-duplicates).
"""
)
query_text = q_cols[0].text_input("Write the text you want a reaction for", value=example_input)
query_images = q_cols[2].file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg", "gif"],
    accept_multiple_files=True,
)
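# Two search modes: rank the uploaded images directly against the text, or
# run a nearest-neighbour search over the precomputed validation embeddings.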
if query_images:
    st.write("Ranking your uploaded images with respect to the input text:")
    with st.spinner("Calculating..."):
        ids, dists = query_with_images(query_images, query_text)
else:
    st.write("Found these images within the validation set:")
    with st.spinner("Calculating..."):
        proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
        vec = np.asarray(model.get_text_features(**proc))
        ids, dists = image_index.knnQuery(vec, k=top_k)
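# Display results in a grid. Validation-set hits come back as integer ids
# into file_names (with cosine distances); uploaded images come back as PIL
# images (with softmax scores).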
show_gif = st.checkbox("Play GIFs", value=True, help="Plays the original animation. Only the first frame is used in training!")
ext = "gif" if show_gif else "jpg"
res_cols = st.columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    j = i % col_count
    with res_cols[j]:
        if isinstance(id_, np.int32):
            # Nearest-neighbour hit: strip the ".jpg" suffix, then show either
            # the static frame or the original GIF.
            st.image(f"./{ext}/{file_names[id_][:-4]}.{ext}")
            st.write(1.0 - dist, help="score")
        else:
            # Uploaded image: id_ is the PIL image itself.
            st.image(id_)
            st.write(dist, help="score")
# Credits
st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")