Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import streamlit as st | |
| import requests | |
| import io | |
| # Designing the interface | |
| st.title("🖼️ Image Captioning Demo 📝") | |
| st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)") | |
| st.sidebar.markdown( | |
| """ | |
| An image captioning model by combining ViT model with GPT2 model. | |
| The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder | |
| framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html). | |
| The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights. | |
| The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256). | |
| [Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n | |
| """ | |
| ) | |
| with st.spinner('Loading and compiling ViT-GPT2 model ...'): | |
| from model import * | |
| random_image_id = get_random_image_id() | |
| st.sidebar.title("Select a sample image") | |
| sample_image_id = st.sidebar.selectbox( | |
| "Please choose a sample image", | |
| sample_image_ids | |
| ) | |
| if st.sidebar.button("Random COCO 2017 (val) images"): | |
| random_image_id = get_random_image_id() | |
| sample_image_id = "None" | |
| bytes_data = None | |
| with st.sidebar.form("file-uploader-form", clear_on_submit=True): | |
| uploaded_file = st.file_uploader("Choose a file") | |
| submitted = st.form_submit_button("Upload") | |
| if submitted and uploaded_file is not None: | |
| bytes_data = io.BytesIO(uploaded_file.getvalue()) | |
| if (bytes_data is None) and submitted: | |
| st.write("No file is selected to upload") | |
| else: | |
| image_id = random_image_id | |
| if sample_image_id != "None": | |
| assert type(sample_image_id) == int | |
| image_id = sample_image_id | |
| sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg" | |
| sample_path = os.path.join(sample_dir, sample_name) | |
| if bytes_data is not None: | |
| image = Image.open(bytes_data) | |
| elif os.path.isfile(sample_path): | |
| image = Image.open(sample_path) | |
| else: | |
| url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg" | |
| image = Image.open(requests.get(url, stream=True).raw) | |
| width, height = image.size | |
| resized = image.resize(size=(width, height)) | |
| if height > 384: | |
| width = int(width / height * 384) | |
| height = 384 | |
| resized = resized.resize(size=(width, height)) | |
| width, height = resized.size | |
| if width > 512: | |
| width = 512 | |
| height = int(height / width * 512) | |
| resized = resized.resize(size=(width, height)) | |
| if bytes_data is None: | |
| st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)") | |
| show = st.image(resized) | |
| show.image(resized, '\n\nSelected Image') | |
| resized.close() | |
| # For newline | |
| st.sidebar.write('\n') | |
| with st.spinner('Generating image caption ...'): | |
| caption = predict(image) | |
| caption_en = caption | |
| st.header(f'Predicted caption:\n\n') | |
| st.subheader(caption_en) | |
| st.sidebar.header("ViT-GPT2 predicts: ") | |
| st.sidebar.write(f"{caption}") | |
| image.close() | |
