Spaces:
Runtime error
Runtime error
| from .utils import ( | |
| get_text_attributes, | |
| get_top_5_predictions, | |
| get_transformed_image, | |
| plotly_express_horizontal_bar_plot, | |
| translate_labels, | |
| bert_tokenizer | |
| ) | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import matplotlib.pyplot as plt | |
| from session import _get_state | |
| from .model.flax_clip_vision_bert.modeling_clip_vision_bert import ( | |
| FlaxCLIPVisionBertForMaskedLM, | |
| ) | |
def softmax(logits, axis=0):
    """Return the softmax of ``logits`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so large
    logits do not overflow to ``inf`` (which previously produced ``nan``
    probabilities). The result is mathematically unchanged by the shift.

    Args:
        logits: Array-like of raw scores.
        axis: Axis to normalise over. Defaults to 0, the axis the original
            implementation always used.

    Returns:
        A NumPy array with the same shape as ``logits`` whose entries along
        ``axis`` are non-negative and sum to 1. Empty input yields an empty
        array.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # A max-reduction over an empty axis would raise; there is nothing to normalise.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)  # computed once, reused for numerator and denominator
    return exps / np.sum(exps, axis=axis, keepdims=True)
def _mask_random_token(caption):
    """Replace one randomly chosen word-piece of ``caption`` with the mask token.

    The caption is encoded without special tokens, so the randomly chosen
    position is always a real word-piece. The decoded string can then be
    re-tokenised later without duplicating special tokens. An empty
    tokenisation is decoded unchanged.
    """
    ids = bert_tokenizer.encode(caption, add_special_tokens=False)
    if ids:
        ids[np.random.randint(len(ids))] = bert_tokenizer.mask_token_id
    return bert_tokenizer.decode(ids)


def _load_example(state, row):
    """Store the image/caption pair described by ``row`` in session ``state``.

    ``row`` is a single row of the example table (a pandas Series) with
    ``image_file``, ``caption`` and ``lang_id`` columns. The caption gets one
    random token masked, and the image is read from ``cc12m_data/images_vqa``.
    """
    state.image_file = row["image_file"]
    state.caption = _mask_random_token(row["caption"].strip("- "))
    state.caption_lang_id = row["lang_id"]
    image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
    state.image = plt.imread(image_path)


def app():
    """Render the masked-language-modelling demo page.

    Shows an image next to an editable caption that should contain one
    ``[MASK]`` token. Plots the model's top-5 predictions for that position.
    The loaded model and the current image/caption live in the object
    returned by ``_get_state()``, so they are reused across reruns.
    """
    state = _get_state()

    def predict(transformed_image, caption_inputs):
        """Return ``(top_5_tokens, top_5_probabilities)`` for the mask position.

        Probabilities come from a softmax over the full vocabulary at the
        masked position, not over just the five best scores. Returns ``None``
        (without running the model) when the input does not contain exactly
        one mask token.
        """
        input_ids = np.asarray(caption_inputs["input_ids"])
        mask_positions = np.where(input_ids == bert_tokenizer.mask_token_id)
        if len(mask_positions[0]) != 1:
            return None
        outputs = state.model(pixel_values=transformed_image, **caption_inputs)
        mask_logits = np.asarray(outputs.logits[mask_positions][0])
        probs = softmax(mask_logits)
        top_5_indices = np.argsort(probs)[::-1][:5]  # highest probability first
        top_5_tokens = bert_tokenizer.convert_ids_to_tokens(top_5_indices.tolist())
        return top_5_tokens, probs[top_5_indices]

    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
    first_index = 20

    # Init session state with a fixed example on the first run.
    if state.image_file is None:
        _load_example(state, dummy_data.loc[first_index])

    # Load the model once and keep it in session state.
    if state.model is None:
        with st.spinner("Loading model..."):
            state.model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(state, dummy_data.sample(1).iloc[0])

    transformed_image = get_transformed_image(state.image)

    # `st.beta_columns` was renamed to `st.columns` (and later removed);
    # use whichever this Streamlit version provides.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    new_col1, new_col2 = make_columns([5, 5])

    # Display Image
    new_col1.image(state.image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )
    caption_inputs = get_text_attributes(caption)

    # Display Top-5 Predictions. `predict` already returns vocabulary tokens,
    # so no separate label-mapping step is needed here.
    with st.spinner("Predicting..."):
        prediction = predict(transformed_image, dict(caption_inputs))
    if prediction is None:
        st.warning("Please write your text with exactly one [MASK] token.")
        return
    labels, values = prediction
    # NOTE(review): values are probabilities in [0, 1]; confirm this matches the
    # scale plotly_express_horizontal_bar_plot expects.
    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.plotly_chart(fig, use_container_width=True)