Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import os | |
| import urllib | |
| import fastai.vision.all as fai_vision | |
| import numpy as np | |
| from pathlib import Path | |
| import pathlib | |
| from PIL import Image | |
| import platform | |
| import altair as alt | |
| import pandas as pd | |
| import frontmatter | |
def main():
    """Render the Streamlit fish masking and classification page.

    Shows the README body and loads the segmentation and classification
    models. Once an image is uploaded, it displays:
      * the original image;
      * the masked image, with the share of pixels labeled "fish";
      * a bar chart of the classifier's top predictions for the masked image.
    """
    import tempfile

    st.title('Fish Masker and Classifier')
    with open('README.md') as readme_file:
        readme = frontmatter.load(readme_file)
    st.markdown(readme.content)

    # NOTE(review): if Streamlit reruns this script on every interaction (its
    # usual execution model), both models are reloaded each time -- consider
    # caching them.
    _, segmenter = load_unet_model()
    classification_model = load_classification_model()

    st.markdown("## Instructions")
    st.markdown("Upload an Amazonian fish photo for masking.")
    uploaded_image = st.file_uploader("", IMAGE_TYPES)
    if not uploaded_image:
        return

    image_data = uploaded_image.read()
    st.markdown('## Original image')
    st.image(image_data, use_column_width=True)

    # PNG uploads may be RGBA, palette or grayscale. JPEG cannot store RGBA/P
    # images, and the masking step expects colour channels, so normalise to
    # 3-channel RGB before saving.
    original_pil = Image.open(uploaded_image).convert('RGB')

    # Keep intermediate files in a private temporary directory so concurrent
    # sessions cannot overwrite each other's 'original.jpg' / 'masked.jpg'.
    with tempfile.TemporaryDirectory() as work_dir:
        original_path = Path(work_dir) / 'original.jpg'
        masked_path = Path(work_dir) / 'masked.jpg'
        original_pil.save(original_path)

        # Re-open the saved JPEG so masking works on the same decoded pixels
        # the segmenter reads from disk.
        with Image.open(original_path) as single_pil:
            input_dl = segmenter.dls.test_dl([original_path])
            masks, _ = segmenter.get_preds(dl=input_dl)
            masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])

        st.markdown('## Masked image')
        st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
        st.image(masked_pil, use_column_width=True)

        # The classifier reads the saved JPEG, not the in-memory image.
        masked_pil.save(masked_path)
        st.markdown('## Classification')
        prediction = classification_model.predict(str(masked_path))

    pred_chart = predictions_to_chart(prediction, classes=classification_model.dls.vocab)
    st.altair_chart(pred_chart, use_container_width=True)
def mask_fish_pil(unmasked_fish, fastai_mask):
    """Black out every pixel the segmenter did not label as fish.

    Args:
        unmasked_fish: PIL image to mask (RGB or single-channel).
        fastai_mask: per-class score tensor for one image. The class axis is
            dim 0 (it is reduced with ``argmax(dim=0)``); presumably shaped
            (n_classes, h, w) -- confirm against the segmenter.

    Returns:
        Tuple ``(masked_pil, percentage_fish)``:
          * ``masked_pil`` is the input image with non-fish pixels set to
            black;
          * ``percentage_fish`` is the share (0-100) of mask pixels whose
            predicted class index is non-zero.
    """
    unmasked_np = np.array(unmasked_fish)
    class_map = fastai_mask.argmax(dim=0).numpy()

    # Any non-zero class index counts as "fish".
    fish_map = class_map != 0
    percentage_fish = (np.count_nonzero(fish_map) / class_map.size) * 100

    # Binary 0/255 mask. A min-max rescale here would divide by zero when no
    # pixel is fish and would zero the whole mask when every pixel is fish.
    mask_u8 = fish_map.astype(np.uint8) * 255

    # Upsample the model-resolution mask to the photo's (width, height).
    target_size = unmasked_np.shape[1::-1]
    mask_u8 = np.array(Image.fromarray(mask_u8).resize(target_size, Image.BILINEAR))

    weights = mask_u8 / 255
    if unmasked_np.ndim == 3:
        # One weight per pixel, broadcast over the colour channels.
        weights = weights[..., np.newaxis]
    masked_fish_np = (unmasked_np * weights).astype(np.uint8)
    return Image.fromarray(masked_fish_np), percentage_fish
def predictions_to_chart(prediction, classes, top_n=4):
    """Build a horizontal bar chart of the classifier's most likely classes.

    Args:
        prediction: tuple returned by the classifier's ``predict``. Only
            element 2 is used: the per-class confidences, in the same order
            as ``classes``.
        classes: class labels aligned with the confidences.
        top_n: how many of the highest-probability classes to plot
            (default 4).

    Returns:
        An ``altair.Chart``:
          * x axis: probability in percent, fixed to the 0-100 domain;
          * y axis: class, ordered by probability, descending.
    """
    rows = [
        {'class': classes[i], 'probability': round(float(conf) * 100, 2)}
        for i, conf in enumerate(prediction[2])
    ]
    # Explicit columns keep 'probability' present even when there are no
    # predictions, so sort_values does not raise KeyError.
    pred_df = pd.DataFrame(rows, columns=['class', 'probability'])
    top_probs = pred_df.sort_values('probability', ascending=False).head(top_n)

    return (
        alt.Chart(top_probs)
        .mark_bar()
        .encode(
            x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
            y=alt.Y("class:N",
                    sort=alt.EncodingSortField(field="probability", order="descending")),
        )
    )
def load_unet_model():
    """Rebuild the fastai U-Net segmenter and load its trained weights.

    A one-image SegmentationDataLoaders is built from 'test_fish.jpg' with:
      * codes "Photo"/"Masks";
      * a 256px squish resize;
      * IntToFloatTensor(div_mask=255);
      * batch size 1 and no worker processes.

    The labels are the images themselves, so this looks like scaffolding for
    constructing the learner rather than real training data (assumption).
    A resnet34 ``unet_learner`` is created on it and then loaded from
    'fish_mask_model'.

    Returns:
        Tuple ``(data_loader, segmenter)``.
    """
    def identity_label(item):
        # Each file is used as its own "label".
        return item

    sample_files = [Path('test_fish.jpg')]
    mask_codes = np.array(["Photo", "Masks"], dtype=str)

    dls = fai_vision.SegmentationDataLoaders.from_label_func(
        path=Path("."),
        fnames=sample_files,
        label_func=identity_label,
        codes=mask_codes,
        valid_pct=0.2,
        item_tfms=[fai_vision.Resize(256, method='squish')],
        batch_tfms=[fai_vision.IntToFloatTensor(div_mask=255)],
        bs=1,
        num_workers=0,
    )

    learner = fai_vision.unet_learner(dls, fai_vision.resnet34)
    learner.load('fish_mask_model')
    return dls, learner
def load_classification_model():
    """Load the exported fastai fish classifier onto the CPU.

    Returns:
        The learner unpickled from ``models/fish_classification_model.pkl``
        by ``load_learner(..., cpu=True)``.
    """
    # The alias implies the pickle references pathlib.WindowsPath (presumably
    # exported on Windows). That class cannot be instantiated on non-Windows
    # systems, so it is aliased to PosixPath before unpickling. This covers
    # every non-Windows platform, not just Linux/Darwin.
    # NOTE(review): the patch is process-wide and is never undone.
    if platform.system() != 'Windows':
        pathlib.WindowsPath = pathlib.PosixPath
    return fai_vision.load_learner('models/fish_classification_model.pkl', cpu=True)
# File extensions accepted by the image uploader in main().
IMAGE_TYPES = "png jpg jpeg".split()

# Run the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()