"""Gradio demo: SegFormer semantic segmentation of satellite images.

Running this file builds a ``gr.Interface`` around
``Segformer_Segmentation`` and launches it. Each request loads an image,
runs a SegFormer checkpoint on it and returns the path of a PNG that shows
the original image next to a randomly coloured segmentation mask.
"""

import functools
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor

# The only checkpoint offered in the UI; also used for every example row.
DEFAULT_MODEL_ID = "deprem-ml/deprem_satellite_semantic_whu"

image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]


def visualize_instance_seg_mask(mask):
    """Colour a 2-D label mask with one random RGB colour per label.

    Args:
        mask: 2-D array of label ids, shape ``(height, width)``.

    Returns:
        A float array of shape ``(height, width, 3)`` with values in
        ``[0, 1]``. Colours are drawn from the ``random`` module, three
        draws per label in ascending label order, so they differ between
        calls unless the caller seeds ``random``.
    """
    mask = np.asarray(mask)
    # ``inverse`` maps every pixel to the index of its label in ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)
    # A single fancy-indexing lookup replaces the per-pixel Python loop.
    # The reshape keeps this independent of whether np.unique returns the
    # inverse flattened or already in the input's shape.
    return palette[inverse.reshape(mask.shape)] / 255


@functools.lru_cache(maxsize=4)
def _load_segformer(model_id):
    """Load (once per ``model_id``) the SegFormer model and its processor."""
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    try:
        processor = SegformerImageProcessor.from_pretrained(model_id)
    except OSError:
        # The checkpoint repo has no preprocessor config; use the library
        # defaults instead.
        processor = SegformerImageProcessor()
    return model, processor


def Segformer_Segmentation(image_path, model_id):
    """Segment an image with SegFormer and save a side-by-side figure.

    Args:
        image_path: Path of the input image file.
        model_id: Hugging Face Hub id (or local path) of a
            ``SegformerForSemanticSegmentation`` checkpoint.

    Returns:
        The path ``"output.png"``, which is overwritten on every call with
        a figure showing the original image and the coloured mask.
    """
    output_save = "output.png"

    # Close the file handle promptly and force 3-channel RGB so RGBA or
    # greyscale PNGs do not reach the model with the wrong channel count.
    with Image.open(image_path) as raw_image:
        test_image = raw_image.convert("RGB")

    model, processor = _load_segformer(model_id)
    model_inputs = processor(images=test_image, return_tensors="pt")
    with torch.no_grad():
        model_outputs = model(**model_inputs)

    # PIL reports (width, height); target_sizes expects (height, width).
    # Without it the mask stays at the reduced logits resolution.
    label_map = processor.post_process_semantic_segmentation(
        model_outputs, target_sizes=[test_image.size[::-1]]
    )[0]
    result = visualize_instance_seg_mask(label_map.cpu().numpy())

    panels = (("Original", test_image), ("Segmentation", result))
    fig, axes = plt.subplots(1, 2, figsize=(10, 10))
    try:
        for axis, (panel_title, panel_image) in zip(axes, panels):
            axis.imshow(panel_image)
            axis.set_title(panel_title)
            axis.axis("off")
        fig.savefig(output_save)
    finally:
        # Figures stay registered with pyplot until closed; release this
        # one so a long-running server does not accumulate them.
        plt.close(fig)

    return output_save


inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        choices=[DEFAULT_MODEL_ID],
        label="Model ID",
        value=DEFAULT_MODEL_ID,
    ),
]

outputs = gr.Image(type="filepath", label="Segmentation")

# One example row per bundled image, all using the default checkpoint.
examples = [[example_path, DEFAULT_MODEL_ID] for example_path in image_list]

title = "Deprem ML - Segformer Semantic Segmentation"

demo_app = gr.Interface(
    Segformer_Segmentation,
    inputs,
    outputs,
    examples=examples,
    title=title,
    cache_examples=True,
)
# ``.queue()`` replaces the deprecated ``launch(enable_queue=True)``.
demo_app.queue()
demo_app.launch(debug=True)