Commit · 2c8f495
1 Parent(s): 405f2d4

Add mask filling app

Files changed:
- app.py +8 -3
- apps/mlm.py +49 -49
- apps/utils.py +1 -0
- apps/vqa.py +44 -42
- multiapp.py +10 -3
- resize_images.py +10 -3
app.py
CHANGED

@@ -1,13 +1,17 @@
 from apps import mlm, vqa
 import os
 import streamlit as st
+from session import _get_state
 from multiapp import MultiApp
 
+
 def read_markdown(path, parent="./sections/"):
     with open(os.path.join(parent, path)) as f:
         return f.read()
 
+
 def main():
+    state = _get_state()
     st.set_page_config(
         page_title="Multilingual VQA",
         layout="wide",
@@ -30,7 +34,7 @@ def main():
     st.write(read_markdown("abstract.md"))
     st.write(read_markdown("caveats.md"))
     st.write("## Methodology")
-    col1, col2 = st.beta_columns([1,1])
+    col1, col2 = st.beta_columns([1, 1])
     col1.image(
         "./misc/article/Multilingual-VQA.png",
         caption="Masked LM model for Image-text Pretraining.",
@@ -43,10 +47,11 @@ def main():
     st.write(read_markdown("checkpoints.md"))
     st.write(read_markdown("acknowledgements.md"))
 
-    app = MultiApp()
+    app = MultiApp(state)
     app.add_app("Visual Question Answering", vqa.app)
    app.add_app("Mask Filling", mlm.app)
     app.run()
+    state.sync()
 
 if __name__ == "__main__":
-    main()
+    main()
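Note on session.py: app.py now threads a shared state object from session._get_state() through MultiApp into each sub-app and calls state.sync() after every run, but session.py itself is not touched by this commit. The snippet below is only a hypothetical stand-in, assuming the common Streamlit session-state helper pattern, to show the interface the pages rely on (unset attributes read as None, writes persist for the session, sync() runs once per rerun); it is not the repository's actual implementation.

class _SessionState:
    """Hypothetical stand-in for the object returned by session._get_state()."""

    def __init__(self):
        # Backing store; written via __dict__ to avoid recursing into __setattr__.
        self.__dict__["_data"] = {}

    def __getattr__(self, name):
        # Unset attributes read back as None, which the
        # "if mlm_state.mlm_image_file is None:" init blocks depend on.
        return self._data.get(name)

    def __setattr__(self, name, value):
        self._data[name] = value

    def sync(self):
        # The real helper re-attaches the stored values to the Streamlit session
        # here so they survive the rerun triggered by each widget interaction.
        pass


_STATE = _SessionState()


def _get_state():
    # The real helper keeps one state object per browser session; a module-level
    # singleton is enough for this sketch.
    return _STATE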
apps/mlm.py
CHANGED

@@ -1,11 +1,9 @@
-
 from .utils import (
     get_text_attributes,
     get_top_5_predictions,
     get_transformed_image,
     plotly_express_horizontal_bar_plot,
-    bert_tokenizer
+    bert_tokenizer,
 )
 
 import streamlit as st
@@ -13,97 +11,99 @@ import numpy as np
 import pandas as pd
 import os
 import matplotlib.pyplot as plt
-
-from session import _get_state
+from mtranslate import translate
 
 
 from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
     FlaxCLIPVisionBertForMaskedLM,
 )
 
+
 def softmax(logits):
     return np.exp(logits) / np.sum(np.exp(logits), axis=0)
 
-def app():
+def app(state):
+    mlm_state = state
 
-    @st.cache(persist=False)
+    # @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
     def predict(transformed_image, caption_inputs):
+        outputs = model(pixel_values=transformed_image, **caption_inputs)
+        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[
+            1
+        ][0]
+        preds = outputs.logits[0][indices]
+        scores = np.array(preds)
+        return scores
+
-    @st.cache(persist=False)
+    # @st.cache(persist=False)
     def load_model(ckpt):
         return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)
 
+    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
     dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
 
     first_index = 20
-    # Init Session
+    # Init Session mlm_state
+    if mlm_state.mlm_image_file is None:
+        mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
         caption = dummy_data.loc[first_index, "caption"].strip("- ")
-        ids = bert_tokenizer(caption)
+        ids = bert_tokenizer.encode(caption)
+        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
+        mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
+        mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
 
+        image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
         image = plt.imread(image_path)
+        mlm_state.mlm_image = image
 
+    #if model is None:
+    # Display Top-5 Predictions
+    with st.spinner("Loading model..."):
+        model = load_model(mlm_checkpoints[0])
 
     if st.button(
         "Get a random example",
         help="Get a random example from the 100 `seeded` image-text pairs.",
     ):
         sample = dummy_data.sample(1).reset_index()
+        mlm_state.mlm_image_file = sample.loc[0, "image_file"]
         caption = sample.loc[0, "caption"].strip("- ")
-        ids = bert_tokenizer(caption)
+        ids = bert_tokenizer.encode(caption)
+        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
+        mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
+        mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
 
+        image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
         image = plt.imread(image_path)
+        mlm_state.mlm_image = image
 
+    transformed_image = get_transformed_image(mlm_state.mlm_image)
 
     new_col1, new_col2 = st.beta_columns([5, 5])
 
     # Display Image
+    new_col1.image(mlm_state.mlm_image, use_column_width="always")
 
     # Display caption
     new_col2.write("Write your text with exactly one [MASK] token.")
     caption = new_col2.text_input(
         label="Text",
+        value=mlm_state.caption,
         help="Type your masked caption regarding the image above in one of the four languages.",
     )
 
+    new_col2.markdown(
+        f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
+    )
     caption_inputs = get_text_attributes(caption)
 
     # Display Top-5 Predictions
     with st.spinner("Predicting..."):
+        scores = predict(transformed_image, dict(caption_inputs))
+        scores = softmax(scores)
+        labels, values = get_top_5_predictions(scores)
+        # newer_col1, newer_col2 = st.beta_columns([6,4])
         fig = plotly_express_horizontal_bar_plot(values, labels)
+        st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
+        st.plotly_chart(fig, use_container_width=True)
+
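The heart of apps/mlm.py is the mask-filling post-processing: locate the [MASK] token in the tokenized caption, take the logits at that position, softmax them, and keep the top 5. Below is a standalone sketch of that logic; the tokenizer checkpoint matches apps/utils.py, while the random logits merely stand in for the output of FlaxCLIPVisionBertForMaskedLM, so the printed tokens are meaningless placeholders.

import numpy as np
from transformers import BertTokenizerFast

# Same checkpoint as apps/utils.py.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")

caption = "a photo of a [MASK] sitting on a bench"
inputs = bert_tokenizer([caption], return_token_type_ids=True, return_tensors="np")

# Column index of the single [MASK] token, as in predict() above.
mask_index = np.where(inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]

# Dummy vocabulary scores at the masked position (stand-in for outputs.logits[0][mask_index]).
seq_len = inputs["input_ids"].shape[1]
scores = np.random.randn(seq_len, bert_tokenizer.vocab_size)[mask_index]

# softmax + top-5, mirroring softmax() and get_top_5_predictions() in the app.
probs = np.exp(scores) / np.sum(np.exp(scores), axis=0)
top5 = np.argsort(probs)[-5:][::-1]
print(bert_tokenizer.convert_ids_to_tokens(top5.tolist()), probs[top5])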
apps/utils.py
CHANGED

@@ -40,6 +40,7 @@ def get_transformed_image(image):
 
 bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")
 
+
 def get_text_attributes(text):
     return bert_tokenizer([text], return_token_type_ids=True, return_tensors="np")
 
apps/vqa.py
CHANGED

@@ -1,4 +1,3 @@
-
 from .utils import (
     get_text_attributes,
     get_top_5_predictions,
@@ -15,29 +14,33 @@ import matplotlib.pyplot as plt
 import json
 
 from mtranslate import translate
-from session import _get_state
 
 
 from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
     FlaxCLIPVisionBertForSequenceClassification,
 )
 
+
 def softmax(logits):
     return np.exp(logits) / np.sum(np.exp(logits), axis=0)
 
 
-def app():
-    state = _get_state()
+def app(state):
+    vqa_state = state
 
-        return np.array(state.model(pixel_values=transformed_image, **question_inputs)[0][0])
+    # @st.cache(persist=False)
+    def predict(transformed_image, question_inputs):
+        return np.array(
+            model(pixel_values=transformed_image, **question_inputs)[0][0]
+        )
 
+    # @st.cache(persist=False)
     def load_model(ckpt):
         return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
 
+    vqa_checkpoints = [
+        "flax-community/clip-vision-bert-vqa-ft-6k"
+    ]  # TODO: Maybe add more checkpoints?
     dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
     code_to_name = {
         "en": "English",
@@ -46,77 +49,76 @@ def app():
         "es": "Spanish",
     }
 
-
     with open("answer_reverse_mapping.json") as f:
         answer_reverse_mapping = json.load(f)
 
     first_index = 20
-    # Init Session
+    # Init Session vqa_state
+    if vqa_state.vqa_image_file is None:
+        vqa_state.vqa_image_file = dummy_data.loc[first_index, "image_file"]
+        vqa_state.question = dummy_data.loc[first_index, "question"].strip("- ")
+        vqa_state.answer_label = dummy_data.loc[first_index, "answer_label"]
+        vqa_state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
+        vqa_state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
+
+        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
         image = plt.imread(image_path)
+        vqa_state.vqa_image = image
 
+    # if model is None:
+
+    # Display Top-5 Predictions
+    with st.spinner("Loading model..."):
+        model = load_model(vqa_checkpoints[0])
 
     if st.button(
         "Get a random example",
         help="Get a random example from the 100 `seeded` image-text pairs.",
     ):
         sample = dummy_data.sample(1).reset_index()
+        vqa_state.vqa_image_file = sample.loc[0, "image_file"]
+        vqa_state.question = sample.loc[0, "question"].strip("- ")
+        vqa_state.answer_label = sample.loc[0, "answer_label"]
+        vqa_state.question_lang_id = sample.loc[0, "lang_id"]
+        vqa_state.answer_lang_id = sample.loc[0, "lang_id"]
 
+        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
         image = plt.imread(image_path)
+        vqa_state.vqa_image = image
 
+    transformed_image = get_transformed_image(vqa_state.vqa_image)
 
     new_col1, new_col2 = st.beta_columns([5, 5])
 
     # Display Image
+    new_col1.image(vqa_state.vqa_image, use_column_width="always")
 
     # Display Question
     question = new_col2.text_input(
         label="Question",
+        value=vqa_state.question,
         help="Type your question regarding the image above in one of the four languages.",
     )
     new_col2.markdown(
+        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
     )
 
     question_inputs = get_text_attributes(question)
 
     # Select Language
     options = ["en", "de", "es", "fr"]
+    vqa_state.answer_lang_id = new_col2.selectbox(
         "Answer Language",
+        index=options.index(vqa_state.answer_lang_id),
         options=options,
         format_func=lambda x: code_to_name[x],
         help="The language to be used to show the top-5 labels.",
     )
 
+    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
     new_col2.markdown(
         "**Actual Answer**: "
+        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
         + " ("
         + actual_answer
         + ")"
@@ -126,6 +128,6 @@ def app():
         logits = predict(transformed_image, dict(question_inputs))
         logits = softmax(logits)
         labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
+        translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
         fig = plotly_express_horizontal_bar_plot(values, translated_labels)
-        st.plotly_chart(fig, use_container_width=True)
+        st.plotly_chart(fig, use_container_width=True)
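apps/vqa.py calls translate_labels(labels, vqa_state.answer_lang_id), which is not defined anywhere in this diff (it presumably lives in apps/utils.py alongside the other helpers). Purely as a hypothetical sketch of what such a helper could look like, using the mtranslate dependency the apps already import; the real implementation is not shown in this commit:

from mtranslate import translate


def translate_labels(labels, lang_id):
    # Hypothetical helper: answer labels are assumed to be English and are
    # translated into the language chosen in the "Answer Language" selectbox.
    if lang_id == "en":
        return list(labels)
    return [translate(label, lang_id, "en") for label in labels]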
multiapp.py
CHANGED

@@ -1,10 +1,17 @@
 import streamlit as st
+from session import _get_state
+
 class MultiApp:
-    def __init__(self):
+    def __init__(self, state):
         self.apps = []
+        self.state = state
+
     def add_app(self, title, func):
         self.apps.append({"title": title, "function": func})
+
     def run(self):
         st.sidebar.header("Tasks")
-        app = st.sidebar.radio(
+        app = st.sidebar.radio(
+            "", self.apps, format_func=lambda app: app["title"]
+        )
+        app["function"](self.state)
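MultiApp.run now hands the list of registered app dicts directly to st.sidebar.radio and calls the selected entry's function with the shared state, so both pages operate on the same session object. A minimal, self-contained illustration of that dispatch pattern follows; the two page functions are placeholders, not the real vqa.app and mlm.app:

import streamlit as st


def vqa_page(state):
    st.write("VQA page; shared state:", state)


def mlm_page(state):
    st.write("Mask-filling page; shared state:", state)


apps = [
    {"title": "Visual Question Answering", "function": vqa_page},
    {"title": "Mask Filling", "function": mlm_page},
]

# Stand-in for the object returned by session._get_state().
shared_state = {"example_key": None}

st.sidebar.header("Tasks")
# st.sidebar.radio returns the selected element of `apps`; format_func only
# controls how each entry is labelled in the sidebar.
app = st.sidebar.radio("", apps, format_func=lambda app: app["title"])
app["function"](shared_state)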
resize_images.py
CHANGED

@@ -7,7 +7,11 @@ def resize_images(path, new_path, num_pixels=300):
     if not os.path.exists(new_path):
         os.makedirs(new_path)
     for filename in os.listdir(path):
+        if not filename.startswith(".") and (
+            filename.endswith(".jpg")
+            or filename.endswith(".jpeg")
+            or filename.endswith(".png")
+        ):
             img = cv2.imread(os.path.join(path, filename))
             height, width, channels = img.shape
             if height > width:
@@ -16,8 +20,11 @@
             else:
                 new_width = num_pixels
                 new_height = int(height * new_width / width)
+            img = cv2.resize(
+                img, (new_width, new_height), interpolation=cv2.INTER_CUBIC
+            )
             cv2.imwrite(os.path.join(new_path, filename), img)
 
+
 # resize_images('./images/val2014', './resized_images/val2014')
+resize_images("./misc/article", "./misc/article/resized", 500)
|