Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from collections import defaultdict | |
| import streamlit as st | |
| from streamlit_drawable_canvas import st_canvas | |
| import matplotlib as mpl | |
| from model import device, segment_image, inpaint | |
| # define utils and helpers | |
def closest_number(n, m=8):
    """Round ``n`` toward zero to a multiple of ``m``.

    Despite the name this does not pick the *nearest* multiple: ``n / m`` is
    truncated with ``int``, so ``closest_number(15)`` is 8, not 16. For the
    positive image sizes this app passes in, that means the result never
    exceeds ``n``.

    Args:
        n: Value to round (int or float).
        m: Divisor the result must be a multiple of; defaults to 8.

    Returns:
        int: ``int(n / m) * m``. This is 0 whenever ``abs(n) < m``.

    Raises:
        ZeroDivisionError: if ``m`` is 0 (for plain Python numbers).
    """
    return int(n / m) * m
def _boxes_from_canvas(json_data, height, width):
    """Convert drawable-canvas JSON objects into integer pixel boxes.

    Returns a list of ``(left, top, right, bottom)`` int tuples clipped to a
    ``width`` x ``height`` canvas; boxes with no area left after clipping are
    dropped.
    """
    objects = pd.json_normalize(json_data.get("objects") or [])
    if objects.empty:
        return []
    lefts = objects["left"].to_numpy(dtype=float)
    tops = objects["top"].to_numpy(dtype=float)
    rights = lefts + objects["width"].to_numpy(dtype=float)
    bottoms = tops + objects["height"].to_numpy(dtype=float)
    # Canvas coordinates are floats, but numpy slice bounds must be ints.
    # Floor the near edges and ceil the far edges so the mask fully covers
    # the drawn shape, then clip to the canvas: a negative start would
    # otherwise wrap around to the opposite side of the array.
    x0 = np.clip(np.floor(np.minimum(lefts, rights)), 0, width).astype(int)
    x1 = np.clip(np.ceil(np.maximum(lefts, rights)), 0, width).astype(int)
    y0 = np.clip(np.floor(np.minimum(tops, bottoms)), 0, height).astype(int)
    y1 = np.clip(np.ceil(np.maximum(tops, bottoms)), 0, height).astype(int)
    return [
        (int(left), int(top), int(right), int(bottom))
        for left, top, right, bottom in zip(x0, y0, x1, y1)
        if right > left and bottom > top
    ]


def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw shapes over ``image`` and burn them into ``mask``.

    Renders an ``st_canvas`` with ``image`` as its background. For every drawn
    object, its bounding box is cropped from ``image`` and displayed, and the
    matching region of ``mask`` is set to 255.

    Args:
        image: PIL image used as the canvas background and for the crops.
        mask: 2-D numpy array indexed ``[row, col]``; modified in place.
        height, width: Canvas size in pixels (boxes are clipped to it).
        drawing_mode: ``st_canvas`` drawing mode; defaults to ``'rect'``.

    Returns:
        The same ``mask`` object, updated in place.
    """
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )
    if canvas_result.json_data is None:
        return mask

    boxes = _boxes_from_canvas(canvas_result.json_data, height, width)
    if boxes:
        for left, top, right, bottom in boxes:
            st.image(image.crop((left, top, right, bottom)))
            mask[top:bottom, left:right] = 255
        st.header("Mask Created!")
        st.image(mask)
    return mask
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` with the chosen edit method.

    Args:
        image: PIL image to edit.
        edit_method: ``"AutoSegment Area"`` (pick segments predicted by
            ``segment_image``) or ``"Draw Custom Area"`` (draw boxes on a
            canvas). Any other value yields an all-zero mask.
        height, width: Size of the default mask in pixels.

    Returns:
        A uint8 mask in which 255 marks the pixels to edit. With
        ``"AutoSegment Area"`` and at least one segment selected this is a PIL
        ``Image``. In every other case it is the ``(height, width)`` numpy
        array, all zeros if nothing was selected.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    if edit_method == "AutoSegment Area":
        # Show a colour-coded view of the predicted segmentation.
        seg_prediction, segment_labels = segment_image(image)
        seg = seg_prediction['segmentation'].cpu().numpy()
        # Integer ids index the colormap LUT directly, so ids 0..max need
        # max + 1 colours. With only max colours, the highest id falls into the
        # "over" colour and looks identical to id max - 1. Keep at least one
        # colour so the colormap is never empty.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # Options are (label_id, label_name) pairs.
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            selected_ids = [label_id for label_id, _ in seg_selections]
            mask = Image.fromarray(np.isin(seg, selected_ids).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)
    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)
    return mask
if __name__ == '__main__':
    st.title("Stable Edit")
    st.title("Edit your photos with Stable Diffusion!")
    st.write(f"Device found: {device}")

    # Parse the downscale factor. Non-integers and values < 1 (which would
    # divide by zero or give negative sizes below) fall back to 2. The error
    # is reported with st.write; the input value is a plain str and has no
    # .write method.
    sf = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
    try:
        sf = int(sf)
    except ValueError:
        sf = 0
    if sf < 1:
        st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    # upload image
    filename = st.file_uploader("upload an image")
    if filename:
        image = Image.open(filename)
        # Downscale by sf, then round each side down to a multiple of 8.
        # Keep at least 8 px so a tiny upload cannot collapse to 0 and make
        # resize fail.
        width, height = image.size
        width = max(closest_number(width / sf), 8)
        height = max(closest_number(height / sf), 8)
        image = image.resize((width, height))
        st.image(image)

        # Select an editing method
        edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
        if edit_method:
            mask = get_mask(image, edit_method, height, width)

            # get inpainted images
            prompt = st.text_input("Please enter prompt for image inpainting", value="")
            if prompt:
                # get_mask marks selected pixels with 255. An all-zero mask
                # means nothing was selected, so skip the slow inpainting run.
                if not np.asarray(mask).any():
                    st.write("Please select an area to edit before inpainting.")
                else:
                    st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
                    images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

                    # display all images
                    st.write("Original Image")
                    st.image(image)
                    for i, img in enumerate(images, 1):
                        st.write(f"result: {i}")
                        st.image(img)