Spaces:
Runtime error
Runtime error
| from .utils import ( | |
| get_text_attributes, | |
| get_top_5_predictions, | |
| get_transformed_image, | |
| plotly_express_horizontal_bar_plot, | |
| translate_labels, | |
| ) | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import matplotlib.pyplot as plt | |
| import json | |
| from mtranslate import translate | |
| from .utils import read_markdown | |
| from .model.flax_clip_vision_bert.modeling_clip_vision_bert import ( | |
| FlaxCLIPVisionBertForSequenceClassification, | |
| ) | |
def softmax(logits, axis=0):
    """Return the softmax of *logits* along *axis* (default 0).

    The maximum of each slice is subtracted before exponentiating. Otherwise
    large logits overflow ``np.exp`` to ``inf`` and the result becomes ``nan``.
    The shift cancels out mathematically, so well-behaved inputs produce the
    same probabilities as the naive formula.

    Args:
        logits: Array-like of raw scores.
        axis: Axis along which the returned probabilities sum to 1.

    Returns:
        A NumPy array with the same shape as *logits*.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # np.max has no identity for empty input. The naive formula yields an
        # empty float array here, so keep that result.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def app(state):
    """Render the multilingual visual question answering (VQA) demo page.

    The page shows usage and intro markdown, an image with an editable
    question, the dataset's reference answer, and a bar chart of the model's
    top-5 answers translated into the selected language.

    Args:
        state: Session-state object whose attributes (``vqa_image_file``,
            ``vqa_image``, ``vqa_model``, ``question``, ``answer_label``,
            ``question_lang_id``, ``answer_lang_id``) persist across Streamlit
            reruns. The ``is None`` checks below assume these start as
            ``None`` on the first run (TODO confirm against the caller).
    """
    vqa_state = state

    # Streamlit promoted the `beta_` layout helpers to `st.expander` /
    # `st.columns` and later removed the `beta_` aliases, which then raise
    # AttributeError. Prefer the stable names and fall back for old versions.
    expander = getattr(st, "expander", None) or st.beta_expander
    columns = getattr(st, "columns", None) or st.beta_columns

    with expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.write(read_markdown("vqa_intro.md"))

    # Not wrapped in st.cache: the loaded model is kept on `vqa_state` instead.
    def predict(transformed_image, question_inputs):
        """Return the first model output for the first batch item as an array.

        The result is used below as the answer logits.
        """
        return np.array(
            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
        )

    def load_model(ckpt):
        """Load the CLIP-Vision-BERT sequence-classification model from *ckpt*."""
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

    def load_sample(frame, index):
        """Copy row *index* of *frame* into the session state and read its image."""
        vqa_state.vqa_image_file = frame.loc[index, "image_file"]
        vqa_state.question = frame.loc[index, "question"].strip("- ")
        vqa_state.answer_label = frame.loc[index, "answer_label"]
        vqa_state.question_lang_id = frame.loc[index, "lang_id"]
        vqa_state.answer_lang_id = frame.loc[index, "lang_id"]
        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
        vqa_state.vqa_image = plt.imread(image_path)

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20
    # Init session state with a fixed example on the first run.
    if vqa_state.vqa_image_file is None:
        load_sample(dummy_data, first_index)

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        load_sample(dummy_data.sample(1).reset_index(), 0)

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = columns([5, 5])

    # Display Image
    new_col1.image(vqa_state.vqa_image, use_column_width="always")

    # Display Question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    # The stored language id only describes the dataset question. Once the user
    # edits the text it may be in any of the four languages, so translate it.
    if question == vqa_state.question and vqa_state.question_lang_id == "en":
        english_question = question
    else:
        english_question = translate(question, "en")
    new_col2.markdown(f"""**English Translation**: {english_question}""")

    question_inputs = get_text_attributes(question)

    # Select Language
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
    new_col2.markdown(
        "**Actual Answer**: "
        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
        + " ("
        + actual_answer
        + ")"
    )

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))
    logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)