Spaces:
Runtime error
Runtime error
Commit
·
4b29c6a
1
Parent(s):
0cb8576
Fix state model issue
Browse files- apps/mlm.py +7 -6
- apps/vqa.py +3 -3
apps/mlm.py
CHANGED
|
@@ -27,12 +27,13 @@ def app(state):
|
|
| 27 |
|
| 28 |
# @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
|
| 29 |
def predict(transformed_image, caption_inputs):
|
| 30 |
-
outputs = mlm_state.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
preds = outputs.logits[0][indices]
|
| 35 |
scores = np.array(preds)
|
|
|
|
| 36 |
return scores
|
| 37 |
|
| 38 |
# @st.cache(persist=False)
|
|
@@ -56,10 +57,10 @@ def app(state):
|
|
| 56 |
image = plt.imread(image_path)
|
| 57 |
mlm_state.mlm_image = image
|
| 58 |
|
| 59 |
-
if mlm_state.
|
| 60 |
# Display Top-5 Predictions
|
| 61 |
with st.spinner("Loading model..."):
|
| 62 |
-
mlm_state.
|
| 63 |
|
| 64 |
if st.button(
|
| 65 |
"Get a random example",
|
|
|
|
| 27 |
|
| 28 |
# @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
|
| 29 |
def predict(transformed_image, caption_inputs):
|
| 30 |
+
outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
|
| 31 |
+
print(outputs.logits.shape)
|
| 32 |
+
indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]
|
| 33 |
+
print(indices)
|
| 34 |
preds = outputs.logits[0][indices]
|
| 35 |
scores = np.array(preds)
|
| 36 |
+
print(scores)
|
| 37 |
return scores
|
| 38 |
|
| 39 |
# @st.cache(persist=False)
|
|
|
|
| 57 |
image = plt.imread(image_path)
|
| 58 |
mlm_state.mlm_image = image
|
| 59 |
|
| 60 |
+
if mlm_state.mlm_model is None:
|
| 61 |
# Display Top-5 Predictions
|
| 62 |
with st.spinner("Loading model..."):
|
| 63 |
+
mlm_state.mlm_model = load_model(mlm_checkpoints[0])
|
| 64 |
|
| 65 |
if st.button(
|
| 66 |
"Get a random example",
|
apps/vqa.py
CHANGED
|
@@ -31,7 +31,7 @@ def app(state):
|
|
| 31 |
# @st.cache(persist=False)
|
| 32 |
def predict(transformed_image, question_inputs):
|
| 33 |
return np.array(
|
| 34 |
-
vqa_state.
|
| 35 |
)
|
| 36 |
|
| 37 |
# @st.cache(persist=False)
|
|
@@ -65,9 +65,9 @@ def app(state):
|
|
| 65 |
image = plt.imread(image_path)
|
| 66 |
vqa_state.vqa_image = image
|
| 67 |
|
| 68 |
-
if vqa_state.
|
| 69 |
with st.spinner("Loading model..."):
|
| 70 |
-
vqa_state.
|
| 71 |
|
| 72 |
# Display Top-5 Predictions
|
| 73 |
|
|
|
|
| 31 |
# @st.cache(persist=False)
|
| 32 |
def predict(transformed_image, question_inputs):
|
| 33 |
return np.array(
|
| 34 |
+
vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
|
| 35 |
)
|
| 36 |
|
| 37 |
# @st.cache(persist=False)
|
|
|
|
| 65 |
image = plt.imread(image_path)
|
| 66 |
vqa_state.vqa_image = image
|
| 67 |
|
| 68 |
+
if vqa_state.vqa_model is None:
|
| 69 |
with st.spinner("Loading model..."):
|
| 70 |
+
vqa_state.vqa_model = load_model(vqa_checkpoints[0])
|
| 71 |
|
| 72 |
# Display Top-5 Predictions
|
| 73 |
|