# Spaces: Runtime error
| from io import BytesIO | |
| import streamlit as st | |
| import pandas as pd | |
| import json | |
| import os | |
| import numpy as np | |
| from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification | |
| from utils import get_transformed_image, get_text_attributes, get_top_5_predictions, plotly_express_horizontal_bar_plot, translate_labels | |
| import matplotlib.pyplot as plt | |
| from mtranslate import translate | |
| from PIL import Image | |
def load_model(ckpt):
    """Load the VQA classification model from a checkpoint.

    Args:
        ckpt: Checkpoint identifier, forwarded unchanged to
            ``FlaxCLIPVisionBertForSequenceClassification.from_pretrained``.

    Returns:
        Whatever ``from_pretrained`` returns for ``ckpt``.
    """
    model_cls = FlaxCLIPVisionBertForSequenceClassification
    model = model_cls.from_pretrained(ckpt)
    return model
def softmax(logits):
    """Return the softmax of ``logits`` taken along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This does
    not change the mathematical result, but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from becoming ``nan``) for large logits.

    Args:
        logits: Array-like of scores with at least one dimension.

    Returns:
        ``numpy.ndarray`` with the same shape as ``logits`` whose entries
        along axis 0 are non-negative and sum to 1.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)  # computed once and reused for the denominator
    return exps / np.sum(exps, axis=0, keepdims=True)
# Model checkpoint paths known to the app.
checkpoints = ['./ckpt/ckpt-60k-5999']  # TODO: Maybe add more checkpoints?

# Tab-separated table of example questions.
dummy_data = pd.read_csv('dummy_vqa_multilingual.tsv', sep='\t')

# Decode the JSON file explicitly as UTF-8 instead of relying on the
# platform's default text encoding.
with open('answer_reverse_mapping.json', encoding='utf-8') as f:
    answer_reverse_mapping = json.load(f)
# Page-level configuration; issued before any other Streamlit element.
st.set_page_config(
    page_title="Multilingual VQA",
    layout="wide",
    initial_sidebar_state="collapsed",
    page_icon="./misc/mvqa-logo.png",
)

st.title("Multilingual Visual Question Answering")

# Newer Streamlit releases expose the expander as `st.expander` and no longer
# provide `st.beta_expander`; older releases only have the beta name. Pick
# whichever exists so the page renders on both.
_expander = st.expander if hasattr(st, "expander") else st.beta_expander

with _expander("About"):
    pass  # section intentionally left empty for now

with _expander("Method"):
    st.image("./misc/Multilingual-VQA.png")

with _expander("Results"):
    pass  # section intentionally left empty for now
def _use_example(examples):
    """Copy the row labelled 0 of ``examples`` into the session state.

    Stores the row's image file name, question, answer label and language id
    (used for both the question and the answer language), then reads the image
    from the ``images`` directory and stores the pixel data as well.

    Args:
        examples: DataFrame with a row labelled ``0`` and the columns
            ``image_file``, ``question``, ``answer_label`` and ``lang_id``.
    """
    state = st.session_state
    state.image_file = examples.loc[0, 'image_file']
    state.question = examples.loc[0, 'question']
    state.answer_label = examples.loc[0, 'answer_label']
    state.question_lang_id = examples.loc[0, 'lang_id']
    state.answer_lang_id = examples.loc[0, 'lang_id']
    state.image = plt.imread(os.path.join('images', state.image_file))


# Init Session State: the first run of a session starts from the first example.
if 'image_file' not in st.session_state:
    _use_example(dummy_data)

# `st.beta_columns` was renamed to `st.columns` in newer Streamlit releases;
# use whichever name this installation provides.
_columns = st.columns if hasattr(st, "columns") else st.beta_columns
col1, col2 = _columns([5, 5])

if col1.button('Get a Random Example'):
    # reset_index() relabels the single sampled row as 0 for _use_example.
    _use_example(dummy_data.sample(1).reset_index())

uploaded_file = col2.file_uploader('Upload your image', type=['png', 'jpg', 'jpeg'])
if uploaded_file is not None:
    st.session_state.image_file = os.path.join('images/val2014', uploaded_file.name)
    # Normalise to 3-channel RGB: palette-mode PNGs would otherwise become an
    # array of palette indices, and greyscale/RGBA files would produce 2-D or
    # 4-channel arrays instead of an H x W x 3 image.
    st.session_state.image = np.array(Image.open(uploaded_file).convert('RGB'))
# Model-ready version of the currently selected image.
transformed_image = get_transformed_image(st.session_state.image)

# Display Image
st.image(st.session_state.image, use_column_width='always')

# Display Question
question = st.text_input(label="Question", value=st.session_state.question)
# Only call the translator when the example's question language is not English.
if st.session_state.question_lang_id == "en":
    english_question = question
else:
    english_question = translate(question, 'en')
st.markdown(f"""**English Translation**: {english_question}""")
question_inputs = get_text_attributes(question)

# Select Language
options = ['en', 'de', 'es', 'fr']
# Preselect the stored answer language. If it is not one of the offered
# options, fall back to the first entry instead of raising ValueError from
# list.index().
current_lang = st.session_state.answer_lang_id
default_index = options.index(current_lang) if current_lang in options else 0
st.session_state.answer_lang_id = st.selectbox('Answer Language', index=default_index, options=options)
# Display Top-5 Predictions
# Streamlit re-executes this script on every widget interaction, so keep the
# loaded model in the session state instead of reloading the checkpoint each
# time. The key includes the checkpoint path so a different checkpoint never
# reuses a stale model. Note the model is held once per browser session.
_model_key = f"model::{checkpoints[0]}"
if _model_key not in st.session_state:
    with st.spinner('Loading model...'):
        st.session_state[_model_key] = load_model(checkpoints[0])
model = st.session_state[_model_key]

with st.spinner('Predicting...'):
    predictions = model(pixel_values=transformed_image, **question_inputs)

# First output, first item: the scores for the single image/question pair.
logits = np.array(predictions[0][0])
logits = softmax(logits)  # from here on `logits` holds the softmax output
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
translated_labels = translate_labels(labels, st.session_state.answer_lang_id)
fig = plotly_express_horizontal_bar_plot(values, translated_labels)
st.plotly_chart(fig)