ydshieh committed · commit c951094 · 1 parent: 5bedd3a
update model
app.py
CHANGED
@@ -1,19 +1,17 @@
 import streamlit as st
-from PIL import Image
-import numpy as np
 
 
 # Designing the interface
-st.title("🖼️
+st.title("🖼️ Image Captioning Demo 📝")
 st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")
 
 st.sidebar.markdown(
     """
-    An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model
+    An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model with the GPT2 model.
     [Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
-    The
-    The
-    The model is trained on
+    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' `FlaxVisionEncoderDecoderModel`.
+    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
+    The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
     """
 )
 
@@ -21,6 +19,7 @@ st.sidebar.markdown(
 #show = st.image(image, use_column_width=True)
 #show.image(image, 'Preloaded Image', use_column_width=True)
 
+
 with st.spinner('Loading and compiling ViT-GPT2 model ...'):
 
     from model import *
@@ -43,16 +42,21 @@ show.image(image, '\n\nSelected Image', width=480)
 # For newline
 st.sidebar.write('\n')
 
+
 with st.spinner('Generating image caption ...'):
 
     caption = predict(image)
 
-    caption_en = translator.translate(caption, src='fr', dest='en').text
-    st.header(f'**Prediction (in French) **{caption}')
-    st.header(f'**English Translation**: {caption_en}')
+    caption_en = caption
+    st.header(f'**Prediction (in English) **{caption_en}')
+
+    # caption_en = translator.translate(caption, src='fr', dest='en').text
+    # st.header(f'**Prediction (in French) **{caption}')
+    # st.header(f'**English Translation**: {caption_en}')
+
 
 st.sidebar.header("ViT-GPT2 predicts:")
-st.sidebar.write(f"**
-
+st.sidebar.write(f"**English**: {caption}")
+
 
-image.close()
+image.close()
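The sidebar text added in app.py summarizes how the model was built: a ViT encoder and a GPT2 decoder joined through transformers' FlaxVisionEncoderDecoderModel, with pretrained weights on both sides, randomly initialized cross-attention, and fine-tuning on COCO 2017. A minimal sketch of that construction, not part of this commit; the "gpt2" decoder checkpoint is an assumption:

from transformers import FlaxVisionEncoderDecoderModel

# Load pretrained ViT (encoder) and GPT2 (decoder) weights; the cross-attention
# weights that connect them are newly (randomly) initialized and still need to be
# trained, e.g. on COCO 2017 as described in the sidebar text.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",  # vision encoder
    "gpt2",                               # text decoder
)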
model.py
CHANGED
@@ -1,83 +1,67 @@
-import os,
-import numpy as np
+import os, shutil
 from PIL import Image
-
 import jax
-from transformers import ViTFeatureExtractor
-from transformers import GPT2Tokenizer
+from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
 from huggingface_hub import hf_hub_download
 
 from googletrans import Translator
 translator = Translator()
 
-current_path = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(current_path)
-
-# Main model - ViTGPT2LM
-from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
 
 # create target model directory
 model_dir = './models/'
 os.makedirs(model_dir, exist_ok=True)
-# copy config file
-filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/config.json")
-shutil.copyfile(filepath, os.path.join(model_dir, 'config.json'))
-# copy model file
-filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/flax_model.msgpack")
-shutil.copyfile(filepath, os.path.join(model_dir, 'flax_model.msgpack'))
-
-flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
-
-vit_model_name = 'google/vit-base-patch16-224-in21k'
-feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
 
-
-
-
-
-
+files_to_download = [
+    "config.json",
+    "flax_model.msgpack",
+    "merges.txt",
+    "special_tokens_map.json",
+    "tokenizer.json",
+    "tokenizer_config.json",
+    "vocab.json",
+]
+
+# copy files from checkpoint hub:
+for fn in files_to_download:
+    file_path = hf_hub_download("ydshieh/vit-gpt2-coco-en", f"ckpt_epoch_3_step_6900/{fn}")
+    shutil.copyfile(file_path, os.path.join(model_dir, fn))
+
+model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
+feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+max_length = 16
+num_beams = 4
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
 
 @jax.jit
-def
-
-    return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
+def generate(pixel_values):
+    output_ids = model.generate(pixel_values, **gen_kwargs).sequences
+    return output_ids
 
 
 def predict(image):
 
-
-
-
-
-    # generation
-    generation = predict_fn(pixel_values)
-
-
-    caption = tokenizer.decode(token_ids)
-    caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
-    caption = caption.replace("à l'arrière-plan", '').replace("Une photo en noir et blanc d'", '').replace("Une photo noire et blanche d'", '').replace("en arrière-plan", '').replace("Un gros plan d'", '').replace("un gros plan d'", '').replace("Une image d'", '')
-    while '  ' in caption:
-        caption = caption.replace('  ', ' ')
-    caption = caption.strip()
-    if caption:
-        caption = caption[0].upper() + caption[1:]
-
-    return caption
+    pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
+    output_ids = generate(pixel_values)
+    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    preds = [pred.strip() for pred in preds]
+
+    return preds[0]
 
 
-def
+def _compile():
 
     image_path = 'samples/val_000000039769.jpg'
     image = Image.open(image_path)
-
     caption = predict(image)
     image.close()
 
-def predict_dummy(image):
-
-    return 'dummy caption!'
-
-
+
+_compile()
+
 
 sample_dir = './samples/'
 sample_fns = tuple([f"{int(f.replace('COCO_val2014_', '').replace('.jpg', ''))}.jpg" for f in os.listdir(sample_dir) if f.startswith('COCO_val2014_')])
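For reference, a minimal usage sketch of the new model.py interface, assuming it is run from the Space's root directory (the print call is illustrative; app.py itself uses `from model import *`):

from PIL import Image
from model import predict

# The first call triggers jax.jit compilation of generate(); model.py warms this
# up itself by running _compile() at import time.
image = Image.open('samples/val_000000039769.jpg')  # sample image shipped with the Space
caption = predict(image)  # beam search with num_beams=4, max_length=16
image.close()
print(caption)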