Spaces:
Runtime error
Runtime error
Commit
·
f384719
1
Parent(s):
571a3f6
Add auto scaling image
Browse files- apps/mlm.py +18 -6
- apps/vqa.py +1 -1
apps/mlm.py
CHANGED
|
@@ -50,8 +50,11 @@ def app(state):
|
|
| 50 |
if mlm_state.mlm_image_file is None:
|
| 51 |
mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
|
| 52 |
caption = dummy_data.loc[first_index, "caption"].strip("- ")
|
|
|
|
| 53 |
ids = bert_tokenizer.encode(caption)
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
| 56 |
mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
|
| 57 |
|
|
@@ -72,8 +75,11 @@ def app(state):
|
|
| 72 |
sample = dummy_data.sample(1).reset_index()
|
| 73 |
mlm_state.mlm_image_file = sample.loc[0, "image_file"]
|
| 74 |
caption = sample.loc[0, "caption"].strip("- ")
|
|
|
|
| 75 |
ids = bert_tokenizer.encode(caption)
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
| 78 |
mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
|
| 79 |
|
|
@@ -99,7 +105,7 @@ def app(state):
|
|
| 99 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
| 100 |
|
| 101 |
# Display Image
|
| 102 |
-
new_col1.image(mlm_state.mlm_image, use_column_width="
|
| 103 |
|
| 104 |
# Display caption
|
| 105 |
new_col2.write("Write your text with exactly one [MASK] token.")
|
|
@@ -109,9 +115,14 @@ def app(state):
|
|
| 109 |
help="Type your masked caption regarding the image above in one of the four languages.",
|
| 110 |
)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
caption_inputs = get_text_attributes(caption)
|
| 116 |
|
| 117 |
# Display Top-5 Predictions
|
|
@@ -119,6 +130,7 @@ def app(state):
|
|
| 119 |
scores = predict(transformed_image, dict(caption_inputs))
|
| 120 |
scores = softmax(scores)
|
| 121 |
labels, values = get_top_5_predictions(scores)
|
|
|
|
| 122 |
# newer_col1, newer_col2 = st.beta_columns([6,4])
|
| 123 |
fig = plotly_express_horizontal_bar_plot(values, labels)
|
| 124 |
st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
|
|
|
|
| 50 |
if mlm_state.mlm_image_file is None:
|
| 51 |
mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
|
| 52 |
caption = dummy_data.loc[first_index, "caption"].strip("- ")
|
| 53 |
+
mlm_state.unmasked_caption = caption
|
| 54 |
ids = bert_tokenizer.encode(caption)
|
| 55 |
+
mask_index = np.random.randint(1, len(ids) - 1)
|
| 56 |
+
mlm_state.currently_masked_token = ids[mask_index]
|
| 57 |
+
ids[mask_index] = bert_tokenizer.mask_token_id
|
| 58 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
| 59 |
mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
|
| 60 |
|
|
|
|
| 75 |
sample = dummy_data.sample(1).reset_index()
|
| 76 |
mlm_state.mlm_image_file = sample.loc[0, "image_file"]
|
| 77 |
caption = sample.loc[0, "caption"].strip("- ")
|
| 78 |
+
mlm_state.unmasked_caption = caption
|
| 79 |
ids = bert_tokenizer.encode(caption)
|
| 80 |
+
mask_index = np.random.randint(1, len(ids) - 1)
|
| 81 |
+
mlm_state.currently_masked_token = ids[mask_index]
|
| 82 |
+
ids[mask_index] = bert_tokenizer.mask_token_id
|
| 83 |
mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
|
| 84 |
mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
|
| 85 |
|
|
|
|
| 105 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
| 106 |
|
| 107 |
# Display Image
|
| 108 |
+
new_col1.image(mlm_state.mlm_image, use_column_width="auto")
|
| 109 |
|
| 110 |
# Display caption
|
| 111 |
new_col2.write("Write your text with exactly one [MASK] token.")
|
|
|
|
| 115 |
help="Type your masked caption regarding the image above in one of the four languages.",
|
| 116 |
)
|
| 117 |
|
| 118 |
+
if caption == mlm_state.caption:
|
| 119 |
+
new_col2.markdown("**Masked Token**: "+mlm_state.currently_masked_token)
|
| 120 |
+
new_col2.markdown("**English Translation: " + mlm_state.unmasked_caption if mlm_state.caption_lang_id == "en" else translate(mlm_state.unmasked_caption, 'en'))
|
| 121 |
+
|
| 122 |
+
else:
|
| 123 |
+
new_col2.markdown(
|
| 124 |
+
f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
|
| 125 |
+
)
|
| 126 |
caption_inputs = get_text_attributes(caption)
|
| 127 |
|
| 128 |
# Display Top-5 Predictions
|
|
|
|
| 130 |
scores = predict(transformed_image, dict(caption_inputs))
|
| 131 |
scores = softmax(scores)
|
| 132 |
labels, values = get_top_5_predictions(scores)
|
| 133 |
+
print(labels)
|
| 134 |
# newer_col1, newer_col2 = st.beta_columns([6,4])
|
| 135 |
fig = plotly_express_horizontal_bar_plot(values, labels)
|
| 136 |
st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
|
apps/vqa.py
CHANGED
|
@@ -109,7 +109,7 @@ def app(state):
|
|
| 109 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
| 110 |
|
| 111 |
# Display Image
|
| 112 |
-
new_col1.image(vqa_state.vqa_image, use_column_width="
|
| 113 |
|
| 114 |
# Display Question
|
| 115 |
question = new_col2.text_input(
|
|
|
|
| 109 |
new_col1, new_col2 = st.beta_columns([5, 5])
|
| 110 |
|
| 111 |
# Display Image
|
| 112 |
+
new_col1.image(vqa_state.vqa_image, use_column_width="auto")
|
| 113 |
|
| 114 |
# Display Question
|
| 115 |
question = new_col2.text_input(
|